{ "cells": [ { "cell_type": "markdown", "id": "5fbc2d16-59f9-4be3-b93e-1a5440c7efd0", "metadata": {}, "source": [ "# Tutorial 6 - Physics-Informed Kolmogorov–Arnold Networks" ] }, { "cell_type": "markdown", "id": "1afe6a1e-3ab4-4f3f-ad47-f6cd66419504", "metadata": {}, "source": [ "One area where KANs have found many applications is the solution of partial differential equations (PDEs), where they replace MLPs as the underlying architecture within the Physics-Informed Machine Learning (PIML) framework. For this reason, `jaxKAN` includes a dedicated `pikan` module with utilities for Physics-Informed Kolmogorov–Arnold Networks (PIKANs), as this framework has come to be known." ] }, { "cell_type": "code", "execution_count": 1, "id": "0a2ef2a6-f681-427f-8252-ade2111ce0e6", "metadata": {}, "outputs": [], "source": [ "import os\n", "# Set before importing JAX so that XLA's C++ log messages are suppressed\n", "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "from jaxkan.models.KAN import KAN\n", "from jaxkan.pikan.pde import get_burgers_res\n", "from jaxkan.pikan.sampling import get_collocs_sobol\n", "\n", "from flax import nnx\n", "import optax\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": null, "id": "5d28ce1e-6475-4e12-90b2-27fa22b41cf7", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "60e70a5d-0340-4bdc-af4e-150d0098a87e", "metadata": {}, "source": [ "## Data Generation" ] },
{ "cell_type": "markdown", "id": "166ad90d-430e-45ca-86a8-e6e4dbc1c943", "metadata": {}, "source": [ "For the purposes of this example, we solve the viscous Burgers' equation,\n", "\n", "$$ \\frac{\\partial u}{\\partial t} + u\\frac{\\partial u}{\\partial x} - \\nu \\frac{\\partial^2 u}{\\partial x^2} = 0,$$\n", "\n", "with $\\nu = \\pi/100$ on the domain $(t, x) \\in \\Omega = [0,1]\\times [-1, 1]$, subject to the initial condition\n", "\n", "$$ u\\left(t=0, x\\right) = -\\sin\\left(\\pi x\\right) $$\n", "\n", "and the homogeneous Dirichlet boundary conditions\n", "\n", "$$ u\\left(t, x=-1\\right) = u\\left(t, x=1\\right) = 0. $$\n", "\n", "To this end, we first sample collocation points. Interior points of $\\Omega$ are where the PDE residual will be enforced. Points on the boundaries $t=0$ and $x=\\pm 1$ are where the initial and boundary conditions will be enforced. All points are drawn from a Sobol sequence and stored as $(t, x)$ pairs." ] }, { "cell_type": "code", "execution_count": 2, "id": "b986e75a-6d4a-402f-bea7-36d13f4a7866", "metadata": {}, "outputs": [], "source": [ "seed = 42\n", "\n", "# Collocation points in the interior of the domain (PDE residual); columns are (t, x)\n", "pde_collocs = get_collocs_sobol(ranges=[(0,1), (-1,1)], total_points=2**12, seed=seed)\n", "\n", "# Collocation points for the IC at t = 0, with targets u(0, x) = -sin(pi x)\n", "ic_collocs = get_collocs_sobol(ranges=[(0,0), (-1,1)], total_points=2**6, seed=seed)\n", "ic_data = -jnp.sin(jnp.pi * ic_collocs[:,1]).reshape(-1,1)\n", "\n", "# Collocation points for the BCs at x = -1 and x = 1, with targets u = 0\n", "bc1_collocs = get_collocs_sobol(ranges=[(0,1), (-1,-1)], total_points=2**6, seed=seed)\n", "bc1_data = jnp.zeros((bc1_collocs.shape[0], 1))\n", "\n", "bc2_collocs = get_collocs_sobol(ranges=[(0,1), (1,1)], total_points=2**6, seed=seed)\n", "bc2_data = jnp.zeros((bc2_collocs.shape[0], 1))\n", "\n", "# Stack IC/BC points and targets so that a single loss term can handle them\n", "bc_collocs = jnp.concatenate([ic_collocs, bc1_collocs, bc2_collocs], axis=0)\n", "bc_data = jnp.concatenate([ic_data, bc1_data, bc2_data], axis=0)" ] }, { "cell_type": "code", "execution_count": null, "id": "004fb416-2223-40b2-9cc0-98551099f04c", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "4e9d9bba-71e2-4d5b-95b1-36167fb70c1a", "metadata": {}, "source": [ "## KAN Model" ] }, { "cell_type": "markdown", "id": "2fd54c6c-3d28-4864-af38-3b22875d0f0a", "metadata": {}, "source": [ "Model selection for KANs was covered in previous tutorials. For this example, we use a KAN built from Chebyshev layers (`layer_type = 'chebyshev'`), with two hidden layers of width 6 and maximum polynomial degree `D = 5`."
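, "\n", "\n", "As a rough guide to what such a layer computes, consider the following sketch. It is an assumption: the exact parameterization of jaxKAN's `'exact'` flavor is not shown in this notebook. Each edge of a Chebyshev KAN layer applies a learnable combination of Chebyshev polynomials of the first kind, up to degree $D$, to its input $x$. The input is typically first squashed into $[-1, 1]$, for example by $\\tanh$:\n", "\n", "$$ \\phi(x) = \\sum_{k=0}^{D} c_k \\, T_k\\left(\\tanh x\\right), \\qquad T_0(z) = 1, \\quad T_1(z) = z, \\quad T_{k+1}(z) = 2z\\,T_k(z) - T_{k-1}(z). $$\n", "\n", "Here the coefficients $c_k$ are the trainable parameters of the edge, and each output node sums the contributions of its incoming edges."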
] }, { "cell_type": "code", "execution_count": 3, "id": "b350b56f-5daa-411f-9090-b544760c34ef", "metadata": {}, "outputs": [], "source": [ "# Initialize a KAN model: 2 inputs (t, x) -> 1 output u(t, x)\n", "n_in = pde_collocs.shape[1]\n", "n_out = 1\n", "n_hidden = 6\n", "\n", "layer_dims = [n_in, n_hidden, n_hidden, n_out]\n", "req_params = {'D': 5, 'flavor': 'exact', 'residual': None, 'external_weights': False, 'init_scheme': {'type': 'glorot_fine'}, 'add_bias': True}\n", "\n", "model = KAN(layer_dims=layer_dims,\n", "            layer_type='chebyshev',\n", "            required_parameters=req_params,\n", "            seed=seed)" ] }, { "cell_type": "code", "execution_count": null, "id": "0ffb6338-c401-4e65-a528-5a575b6eafea", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "f681d135-c885-4e59-87d7-1ea16a157c50", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 4, "id": "6723a412-c313-453c-aa88-9274687ee54e", "metadata": {}, "outputs": [], "source": [ "# Adam optimizer; wrt=nnx.Param restricts updates to the model's trainable parameters\n", "opt_type = optax.adam(learning_rate=0.001)\n", "\n", "optimizer = nnx.Optimizer(model, opt_type, wrt=nnx.Param)" ] }, { "cell_type": "markdown", "id": "eaac2ddc-ad53-40db-804b-7f54f3000476", "metadata": {}, "source": [ "This is not a supervised learning problem, because no solution values are available in the interior of the domain. Instead, we define two loss terms. One penalizes the PDE residual at the interior collocation points. The other enforces the initial and boundary conditions at the points in `bc_collocs`."
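, "\n", "\n", "In formulas, and assuming that `burgers_res` returns the pointwise residual of the equation above, the loss minimized by `train_step` is\n", "\n", "$$ \\mathcal{R}[u_\\theta](t, x) = \\frac{\\partial u_\\theta}{\\partial t} + u_\\theta\\frac{\\partial u_\\theta}{\\partial x} - \\nu \\frac{\\partial^2 u_\\theta}{\\partial x^2}, $$\n", "\n", "$$ \\mathcal{L}(\\theta) = \\frac{1}{N_r}\\sum_{i=1}^{N_r} \\mathcal{R}[u_\\theta](t_i, x_i)^2 + \\frac{1}{N_b}\\sum_{j=1}^{N_b} \\left( u_\\theta(t_j, x_j) - g_j \\right)^2, $$\n", "\n", "where $u_\\theta$ is the KAN with parameters $\\theta$. The $(t_i, x_i)$ are the $N_r = 2^{12}$ interior points in `pde_collocs`. The $(t_j, x_j)$ and their targets $g_j$ are the $N_b = 3 \\cdot 2^6 = 192$ initial/boundary points and values stacked in `bc_collocs` and `bc_data`. The two terms are weighted equally."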
] }, { "cell_type": "code", "execution_count": 5, "id": "f18ab849-3c05-418a-a30b-3928f77c332d", "metadata": {}, "outputs": [], "source": [ "# Residual function for Burgers' equation, evaluated as burgers_res(model, collocs)\n", "burgers_res = get_burgers_res()\n", "\n", "# Single JIT-compiled training step\n", "@nnx.jit\n", "def train_step(model, optimizer, pde_collocs, bc_collocs, bc_data):\n", "\n", "    def loss_fn(model):\n", "        # PDE part: mean squared residual at the interior collocation points\n", "        pde_res = burgers_res(model, pde_collocs)\n", "        total_loss = jnp.mean(pde_res**2)\n", "\n", "        # IC/BC part: mean squared mismatch with the prescribed values\n", "        bc_res = model(bc_collocs) - bc_data\n", "        total_loss += jnp.mean(bc_res**2)\n", "\n", "        return total_loss\n", "\n", "    loss, grads = nnx.value_and_grad(loss_fn)(model)\n", "    optimizer.update(model, grads)\n", "\n", "    return loss" ] }, { "cell_type": "code", "execution_count": 6, "id": "b06282a5-8899-4801-9c9e-d8b73b55d74a", "metadata": {}, "outputs": [], "source": [ "# Pre-allocate an array for the per-epoch losses\n", "num_epochs = 5000\n", "train_losses = jnp.zeros((num_epochs,))\n", "\n", "for epoch in range(num_epochs):\n", "    # Perform one optimization step and get the current loss\n", "    loss = train_step(model, optimizer, pde_collocs, bc_collocs, bc_data)\n", "\n", "    # Store the loss\n", "    train_losses = train_losses.at[epoch].set(loss)" ] }, { "cell_type": "code", "execution_count": null, "id": "05e14e4d-bb71-464f-bd44-27d4a7d3aab5", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "faab949e-dadb-4cec-bc13-63b10cb609c5", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": 7, "id": "86630d29-9b9f-452a-b2f1-399b02768829", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAnAAAAGJCAYAAAAKZg7vAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAfORJREFUeJzt3Xl4U1X6B/BvliZp06bpvtCWlpYutBToQgWXwaEj4grqiMggiyPjgqM/dBwdRwVHh3EZhsHBZVTEDXBFHRUUUVyhK6V0p6UblG50Tds0TXJ+f7S5NHRLl/Se0vfzPH2e5uYmOb3fe8PLvfecI2GMMRBCCCGEkAlDKnYDCCGEEELI8FABRwghhBAywVABRwghhBAywVABRwghhBAywVABRwghhBAywVABRwghhBAywVABRwghhBAywVABRwghhBAywVABRwghhBAywVABRwixsnr1agQHB4/otRs3boREIhnbBhEyBMt+V19fL3ZTCBk3VMARMkFIJBKbfg4dOiR2U0WxevVqODs7i90MmzDG8Pbbb+Oyyy6DVquFk5MTZs6ciSeffBJtbW1iN68PS4E00E91dbXYTSRk0pGL3QBCiG3efvttq8dvvfUWDhw40Gd5VFTUqD7n1VdfhdlsHtFr//rXv+Lhhx8e1edf6EwmE2699Va8//77uPTSS7Fx40Y4OTnhxx9/xKZNm/DBBx/gm2++gY+Pj9hN7eOll17qt0jWarXj3xhCJjkq4AiZIH73u99ZPT5y5AgOHDjQZ/n52tvb4eTkZPPnODg4jKh9ACCXyyGX09fKYJ599lm8//77ePDBB/Hcc88Jy9etW4ebb74ZS5YswerVq7Fv375xbZct+8lNN90ET0/PcWoRIWQwdAmVkAvIggULEBMTg4yMDFx22WVwcnLCX/7yFwDAp59+iquvvhr+/v5QKpUIDQ3F3/72N5hMJqv3OP8euLKyMkgkEjz//PP473//i9DQUCiVSiQmJiItLc3qtf3dAyeRSLB+/Xp88skniImJgVKpRHR0NPbv39+n/YcOHUJCQgJUKhVCQ0PxyiuvjPl9dR988AHi4+Ph6OgIT09P/O53v8Pp06et1qmursaaNWsQEBAApVIJPz8/XH/99SgrKxPWSU9Px6JFi+Dp6QlHR0eEhIRg7dq1g352R0cHnnvuOYSHh2Pz5s19nr/22muxatUq7N+/H0eOHAEAXHPNNZg2bVq/7zdv3jwkJCRYLXvnnXeEv8/d3R233HILKisrrdYZbD8ZjUOHDkEikeC9997DX/7yF/j6+kKtVuO6667r0wbAtiwAoKCgADfffDO8vLzg6OiIiIgIPProo33Wa2pqwurVq6HVauHq6oo1a9agvb3dap0DBw7gkksugVarhbOzMyIiIsbkbydkvNF/lQm5wJw9exaLFy/GLbfcgt/97nfCpbidO3fC2dkZGzZsgLOzM7799ls8/vjjaGlpsToTNJBdu3ahtbUVf/jDHyCRSPDss8/ihhtuwMmTJ4c8a/fTTz/h448/xt133w0XFxds27YNN954IyoqKuDh4QEAOHr0KK688kr4+flh06ZNMJlMePLJJ+Hl5TX6jdJj586dWLNmDRITE7F582bU1NTg3//+N37++WccPXpUuBR44403Ijc3F/feey+Cg4NRW1uLAwcOoKKiQnh8xRVXwMvLCw8//DC0Wi3Kysrw8ccfD7kdGhsbcd999w14pvK2227DG2+8gc8//xwXXXQRli1bhttuuw1paWlITEwU1isvL8eRI0essnv66afx2GOP4eabb8bvf/971NXV4YUXXsBll11m9fcBA+8ng2loaOizTC6X97mE+vTTT0MikeDPf/4zamtrsXXrViQnJyMrKwuOjo4AbM8iOzsbl156KRwcHLBu3ToEBwejpKQE//vf//D0009bfe7NN9+MkJAQbN68GZmZmXjttdfg7e2NZ555BgCQm5uLa665BrGxsXjyySehVCpRXFy
Mn3/+eci/nRDuMELIhHTPPfew8w/hX/3qVwwAe/nll/us397e3mfZH/7wB+bk5MT0er2wbNWqVWzq1KnC49LSUgaAeXh4sIaGBmH5p59+ygCw//3vf8KyJ554ok+bADCFQsGKi4uFZceOHWMA2AsvvCAsu/baa5mTkxM7ffq0sOzEiRNMLpf3ec/+rFq1iqnV6gGfNxgMzNvbm8XExLCOjg5h+eeff84AsMcff5wxxlhjYyMDwJ577rkB32vv3r0MAEtLSxuyXb1t3bqVAWB79+4dcJ2GhgYGgN1www2MMcaam5uZUqlkDzzwgNV6zz77LJNIJKy8vJwxxlhZWRmTyWTs6aeftlrv+PHjTC6XWy0fbD/pjyXX/n4iIiKE9b777jsGgE2ZMoW1tLQIy99//30GgP373/9mjNmeBWOMXXbZZczFxUX4Oy3MZnOf9q1du9ZqnaVLlzIPDw/h8b/+9S8GgNXV1dn0dxPCM7qESsgFRqlUYs2aNX2WW858AEBrayvq6+tx6aWXor29HQUFBUO+77Jly+Dm5iY8vvTSSwEAJ0+eHPK1ycnJCA0NFR7HxsZCo9EIrzWZTPjmm2+wZMkS+Pv7C+uFhYVh8eLFQ76/LdLT01FbW4u7774bKpVKWH711VcjMjISX3zxBYDu7aRQKHDo0CE0Njb2+16Ws0Off/45urq6bG5Da2srAMDFxWXAdSzPtbS0AAA0Gg0WL16M999/H4wxYb333nsPF110EYKCggAAH3/8McxmM26++WbU19cLP76+vpg+fTq+++47q88ZaD8ZzEcffYQDBw5Y/bzxxht91rvtttus/sabbroJfn5++PLLLwHYnkVdXR1++OEHrF27Vvg7Lfq7rH7nnXdaPb700ktx9uxZYVtacvv0009H3FGHEF5QAUfIBWbKlClQKBR9lufm5mLp0qVwdXWFRqOBl5eX0AGiubl5yPc9/x9QSzE3UJEz2Gstr7e8tra2Fh0dHQgLC+uzXn/LRqK8vBwAEBER0ee5yMhI4XmlUolnnnkG+/btg4+PDy677DI8++yzVkNl/OpXv8KNN96ITZs2wdPTE9dffz3eeOMNdHZ2DtoGS1FjKeT601+Rt2zZMlRWVuLw4cMAgJKSEmRkZGDZsmXCOidOnABjDNOnT4eXl5fVT35+Pmpra60+Z6D9ZDCXXXYZkpOTrX7mzZvXZ73p06dbPZZIJAgLCxPuIbQ1C0uBHxMTY1P7htpHly1bhosvvhi///3v4ePjg1tuuQXvv/8+FXNkQqICjpALTO8zbRZNTU341a9+hWPHjuHJJ5/E//73Pxw4cEC4N8iWf8BkMlm/y3ufFbLHa8Vw//33o6ioCJs3b4ZKpcJjjz2GqKgoHD16FEB3QfLhhx/i8OHDWL9+PU6fPo21a9ciPj4eOp1uwPe1DPGSnZ094DqW52bMmCEsu/baa+Hk5IT3338fAPD+++9DKpXit7/9rbCO2WyGRCLB/v37+5wlO3DgAF555RWrz+lvP5nohtrPHB0d8cMPP+Cbb77BypUrkZ2djWXLluE3v/lNn848hPCOCjhCJoFDhw7h7Nmz2LlzJ+677z5cc801SE5OtrokKiZvb2+oVCoUFxf3ea6/ZSMxdepUAEBhYWGf5woLC4XnLUJDQ/HAAw/g66+/Rk5ODgwGA/75z39arXPRRRfh6aefRnp6Ot59913k5uZiz549A7bB0vtx165dAxYMb731FoDu3qcWarUa11xzDT744AOYzWa89957uPTSS60uN4eGhoIxhpCQkD5nyZKTk3HRRRcNsYXGzokTJ6weM8ZQXFws9G62NQtL79ucnJwxa5tUKsXChQuxZcsW5OXl4emnn8a3337b5xIzIbyjAo6QScByZqL3GS+DwYAXX3xRrCZZkclkSE5OxieffIKqqipheXFx8ZiNh5aQkABvb2+8/PLLVpc69+3bh/z8fFx99dUAusdD0+v1Vq8NDQ2Fi4uL8LrGxsY+Zw9
nz54NAINeRnVycsKDDz6IwsLCfofB+OKLL7Bz504sWrSoT8G1bNkyVFVV4bXXXsOxY8esLp8CwA033ACZTIZNmzb1aRtjDGfPnh2wXWPtrbfesrpM/OGHH+LMmTPC/Yy2ZuHl5YXLLrsMO3bsQEVFhdVnjOTsbX+9aG3JjRAe0TAihEwC8+fPh5ubG1atWoU//vGPkEgkePvtt7m6hLlx40Z8/fXXuPjii3HXXXfBZDLhP//5D2JiYpCVlWXTe3R1deGpp57qs9zd3R133303nnnmGaxZswa/+tWvsHz5cmHoiuDgYPzf//0fAKCoqAgLFy7EzTffjBkzZkAul2Pv3r2oqanBLbfcAgB488038eKLL2Lp0qUIDQ1Fa2srXn31VWg0Glx11VWDtvHhhx/G0aNH8cwzz+Dw4cO48cYb4ejoiJ9++gnvvPMOoqKi8Oabb/Z53VVXXQUXFxc8+OCDkMlkuPHGG62eDw0NxVNPPYVHHnkEZWVlWLJkCVxcXFBaWoq9e/di3bp1ePDBB23ajgP58MMP+52J4Te/+Y3VMCTu7u645JJLsGbNGtTU1GDr1q0ICwvDHXfcAaB7sGhbsgCAbdu24ZJLLkFcXBzWrVuHkJAQlJWV4YsvvrB5v7B48skn8cMPP+Dqq6/G1KlTUVtbixdffBEBAQG45JJLRrZRCBGLKH1fCSGjNtAwItHR0f2u//PPP7OLLrqIOTo6Mn9/f/bQQw+xr776igFg3333nbDeQMOI9DesBgD2xBNPCI8HGkbknnvu6fPaqVOnslWrVlktO3jwIJszZw5TKBQsNDSUvfbaa+yBBx5gKpVqgK1wzqpVqwYc6iI0NFRY77333mNz5sxhSqWSubu7sxUrVrBTp04Jz9fX17N77rmHRUZGMrVazVxdXVlSUhJ7//33hXUyMzPZ8uXLWVBQEFMqlczb25tdc801LD09fch2MsaYyWRib7zxBrv44ouZRqNhKpWKRUdHs02bNjGdTjfg61asWMEAsOTk5AHX+eijj9gll1zC1Go1U6vVLDIykt1zzz2ssLBQWGew/aQ/gw0j0nv/sQwjsnv3bvbII48wb29v5ujoyK6++uo+w4AwNnQWFjk5OWzp0qVMq9UylUrFIiIi2GOPPdanfecPD/LGG28wAKy0tJQx1r1/XX/99czf358pFArm7+/Pli9fzoqKimzeFoTwQsIYR/8FJ4SQ8yxZsgS5ubl97qsi/Dl06BAuv/xyfPDBB7jpppvEbg4hFzS6B44Qwo2Ojg6rxydOnMCXX36JBQsWiNMgQgjhFN0DRwjhxrRp07B69WpMmzYN5eXleOmll6BQKPDQQw+J3TRCCOEKFXCEEG5ceeWV2L17N6qrq6FUKjFv3jz8/e9/7zMwLCGETHZ0DxwhhBBCyARD98ARQgghhEwwVMARQgghhEwwdA/cIMxmM6qqquDi4gKJRCJ2cwghhBBygWOMobW1Ff7+/pBKBz7PRgXcIKqqqhAYGCh2MwghhBAyyVRWViIgIGDA56mAG4SLiwuA7o2o0Wjs9jnp6elISEiw2/sT21EWfKAc+EFZ8IOy4IO9c2hpaUFgYKBQgwyECrhBWC6bajQauxZwarXaru9PbEdZ8IFy4AdlwQ/Kgg/jlcNQt27RMCKDaGlpgaurK5qbm+0aFmOM7rHjBGXBB8qBH5QFPygLPtg7B1trD+qFyoGsrCyxm0B6UBZ8oBz4QVnwg7LgAy85UAHHAYPBIHYTSA/Kgg+UAz8oC35QFnzgJQe6B44DWq1W7CaQHpQFHygHflAWfGCMQaPRQK/Xi92USW+0OchkMsjl8lFfhqUCjgODdRMm44uy4APlwA/KQnwGgwFnzpyB2WxGaWmp2M2Z9MYiBycnJ/j5+UGhUIz4PaiA40BOTg6SkpLEbgYBZcELyoEflIW4LMWCTCaDp6cnNBoNdWQQWVtbG9Rq9YheyxiDwWBAXV0dSktLMX369EEH6x0MFXCEEEIIpww
GA8xmMwIDA2E2m+Ho6Ch2kyY9o9EIlUo14tc7OjrCwcEB5eXlMBgMI34vKuA4MG3aNLGbQHpQFnygHPhBWfBBKpXCwcFB7GYQAEqlctTvMdKzblbvMep3IKNGN6Xyg7LgA+XAD8qCHzRsKx94yYEKOJF9d7QCn3yfD10HH92SJ7uqqiqxm0BAOfCEsuAHL8NXTHa85EAFnMju3/YdXviyEq9/kSN2UwghhBBuBQcHY+vWrWI3gxtUwIms02AEAHzwbYHILSEAEB8fL3YTCCgHnlAW/Bhpz8fxJpFIBv3ZuHHjiN43LS0N69atG1XbFixYgPvvv39U78FLDlTAicxJ1X1T6g0LwkVuCQGA3NxcsZtAQDnwhLLgR0dHh9hNsMmZM2eEn61bt0Kj0Vgte/DBB4V1GWMwGo02va+XlxecnJzs1Wyb8ZIDFXAic3bsHsRvUWKwuA0hAOiGbV5QDvygLPhhNpvBGEO7vkuUH1tv3vf19RV+XF1dIZFIhMcFBQVwcXHBvn37EB8fD6VSiZ9++gklJSW4/vrr4ePjA2dnZyQmJuKbb76xet/zL6FKJBK89tprWLp0KZycnDB9+nR89tlno9rGH330EaKjo6FUKhEcHIx//vOfVs+/+OKLiI2NhUqlgo+PD2666SbhuQ8//BAzZ86Eo6MjPDw8kJycjLa2tlG1ZzA0jIjIZNLuARmNJj56tUx2Go1G7CYQUA48oSz4IZPJ0NFpxOy1b4ny+Vk7bhOuGo3Www8/jOeffx7Tpk2Dm5sbKisrcdVVV+Hpp5+GUqnEW2+9hWuvvRaFhYUICgoa8H02bdqEZ599Fs899xxeeOEFrFixAuXl5XB3dx92mzIyMnDzzTdj48aNWLZsGX755Rfcfffd8PDwwOrVq5Geno4//vGPeP3117FgwQI0NDTgxx9/BNB91nH58uV49tlnsXTpUrS2tuLHH3+0a49VKuBEJu0p4Mxms8gtIUD3//CI+CgHflAW/FAqldAbTGI3Y0w8+eST+M1vfiM8dnd3x6xZs4THf/vb37B371589tlnWL9+/YDvs3r1aixfvhwA8Pe//x3btm1DamoqrrzyymG3acuWLVi4cCEee+wxAEB4eDjy8vLw3HPPYfXq1aioqIBarcaSJUvg6uqKqVOnYs6cOQC6Czij0YgbbrgBU6dOBQDMnDlz2G0YDirgRGY5A2cy0xk4HmRnZ9O0QRygHPhBWfCjvb0darUaWTtuG9HrP/iuEDu+zMHaq2Lw28sjhv16R+XYlQwJCQlWj3U6HTZu3IgvvvhCKIY6OjpQUVEx6PvExsYKv6vVamg0GtTW1o6oTfn5+bj++uutll188cXYunUrTCYTfvOb32Dq1KkIDQ3F4sWLceWVVwqXb2fNmoWFCxdi5syZWLRoEa644grcdNNNcHNzG1FbbEEFnMhksu7bEKmAI4QQMhSJRDLiy5irFsdg1eKYMW7RyJzfk/PBBx/EgQMH8PzzzyMsLAyOjo646aabhhxz7fzZKSQSid2uaLm4uCAzMxP79u3Djz/+iMcffxwbN25EWloatFotDhw4gF9++QVff/01XnjhBTz66KNISUlBSEiIXdpDnRhERmfg+GI59U3ERTnwg7Lgx1hM4cSrn3/+GatXr8bSpUsxc+ZM+Pr6oqysbFzbEBUVhZ9//rlPu8LDwyGTyQAAcrkcV155JZ599llkZ2ejrKwM3377LYDu4vHiiy/Gpk2bcPToUSgUCuzdu9du7aUzcCKje+D4YjJdGPeXTHSUAz8oC37wMoWTPUyfPh0ff/wxrr32WkgkEjz22GN2+3exrq4OWVlZVsv8/PzwwAMPIDExEX/729+wbNkyHD58GP/5z3/w4osvAgA+//xznDx5EvPmzYO3tze+/PJLmM1mREREICUlBQcPHsQVV1wBb29vpKSkoK6uDlFRUXb5GwA6Aye65tZOAMDBjMGv85PxcerUKbGbQEA58ISy4AcvUzjZw5YtW+D
m5ob58+fj2muvxaJFixAXF2eXz9q1axfmzJlj9fPqq68iLi4O77//Pvbs2YOYmBg8/vjjePLJJ7F69WoAgFarxccff4xFixYhKioKL7/8Mnbv3o3o6GhoNBr88MMPuOqqqxAeHo6//vWv+Oc//4nFixfb5W8AAAm7kEv6UWppaYGrqyuam5vt1pU++rY30GU0w12jwpGXV9jlM4jtUlJS6IZtDlAO/KAsxKXX61FaWoqQkBAYjUY4OzuL3aRJT6fTjTqH3rmqVCqr52ytPegSaj+2b9+O7du3C5cO0tPToVarERcXh/z8fHR0dMDFxQUhISHIzs4G0H2fiNlsRmVlJQBg9uzZKC4uhk6ng1qtRnh4OI4ePQoACAgIgEwmQ3l5ObROMtS1mJEY5oqUlBSoVCpER0cjIyMDAODv7w+VSoWTJ08CAGJiYnDq1Ck0NTVBoVBg9uzZSE1NBdA9eKKzszOKi4sBdF/Pr6mpQUNDA+RyOeLj45GamgrGGLy8vODm5oaioiIAQEREBBoaGlBXVwepVIrExESkp6fDZDLBw8MD3t7eyM/PB9B9qrulpQU1NTUAgKSkJGRmZqKrqwtubm7w9/cXRm8PDQ1Fe3s7zpw5A6C751FOTg70ej1cXV0RFBSE48ePA+gersBoNAr/44+Li0NBQQHa29vh7OyM0NBQHDt2DACEcYEsPZRmzZqFkpIS6HQ6ODk5ITIyEpmZmcL2lsvlwv0UM2fOREVFBZqbm6FSqRATE4P09HQAgLe3N+rr61FSUgIAiI6ORlVVFRobG+Hg4IC4uDikpKQAAHx8fKDRaHDixAlhe9fW1uLs2bOQyWRISEhAWloazGYzvLy84O7ujsLCQgDd3dMbGxtRV1cHiUSCuXPnIiMjA0ajEe7u7vDx8RG2d1hYGHQ6HaqrqwEAc+fORVZWFgwGA7RaLQICApCT0z2X7rRp06DX64UJyOPj45Gbmwu9Xg+NRoPg4GCrfdZkMgnbe86cOSgqKkJbWxucnZ0RFhYmXGYIDAyEVCpFeXk5gO6eX6WlpWhtbYWjoyOioqKE7T1lyhQoFAqUlpYK27uyshJNTU1QKpWIjY1FWlqasM+q1Wphe8+YMQPV1dUwm83IzMy02t7e3t5wdXUVtndkZCTq6+tRX18v7LOW7e3p6QlPT08UFBQI+2xzc7PQO633Puvu7g5fX1/k5eUJ+2xbW5uwvRMTE5GdnY3Ozk5otVoEBgYK+2xISAgMBgNOnz4t7LNj/R1h2d5lZWVoaWkZ9+8Is9mMlJQU+o7o+Y7w8/ODk5PTuH5HWO59k0gk0Ol0kMvlcHBwEGYEUKlUMJlM6OrqAtDdQaCjowNms7nPukqlEowx4Wxe73VlMhmUSiXa29v7XdfJyQmdnZ0wmUx91lUoFJBIJOjs7OyzrlQqhUqlGnBdR0dHdHV1wWg0QiqVwtHRURgA18HBAVKptN91JRIJ1Go1dDqdsK5MJhMGn1apVDAajVbrtrW1gTEGuVwOuVxutW7vbejs7Gy17vnbsLOzc0Tb28nJCXq9Hh0dHejq6oJerxf2Wct3hOV7ayh0Bm4Q43EGbvmmz5FRWIMX7v81Fs21T08VYrucnBzExPDRS2syoxz4QVmIq/eZGsYYHB0dxW7SpNfR0THqHMbiDBzdAycyOQ0jwhV7TntCbEc58IOy4Ad1KOEDLzlQAScyqaRnGBGaSosLdH8JHygHflAW/LAMZUHExUsOVMCJ7Nw4cDSMCA/CwsLEbgIB5cATyoIPjLELehy4iWQschiLu9eogBPZuXHg6AwcD84fG4iIg3LgB2UhLstMA+3t7UInACKuscjB8h7nzyQxHNQLVWSWe+CMVMARQgg5j0wmg1arRW1tLTQaDWQyGSQ9t94QcXR2dkIuH1n5xBhDe3s7amtrodVqR3U5lgo4kVXVd3d/Tsmtws0jmFyYjK3AwECxm0BAOfCEshCfr68vAKC+vh4tLS0it4ZYhlEZDa1WK+Q
6UlTAiaysuhkA8MMxGu2cB1Ip3VXAA8qBH5SF+CQSCfz8/MAYg7u7u9jNmfTq6urg5eU14tdbxqsbLSrgRBY6RYvc0rOYP3OK2E0hAMrLy0f9vyIyepQDPygLflRWVsLf31/sZkx6VVVVXJyZpv9aiWyanxYAMDvMW9yGEEIIIWTCoAJOZNQLlS+xsbFiN4GAcuAJZcEPyoIPvORABZzIaBw4vljm7yTiohz4QVnwg7LgAy85UAEnMpnMUsDRGTgetLa2it0EAsqBJ5QFPygLPvCSAxVwIiut6u6FmnWiVuSWEAA0UTQnKAd+UBb8oCz4wEsOVMCJLK/8LAAgvaBa5JYQAIiKihK7CQSUA08oC35QFnzgJQcq4EQ2c5onAGBOOPVC5UFmZqbYTSCgHHhCWfCDsuADLzlQASeyqKkeAICIIA+RW0IIIYSQiYIKOJHJekY53/1NPnZ/ky9ya8iUKTSgMg8oB35QFvygLPjASw5UwInMMpm9rqMLr3yWLXJriEKhELsJBJQDTygLflAWfOAlByrgROak6p7NzEkpxx+u42NwwMmMl/F9JjvKgR+UBT8oCz7wkgMVcCJTqxwAAGZG48ARQgghxDZUwInMUsDpDSZseS9D5NaQmTNnit0EAsqBJ5QFPygLPvCSAxVwInNRn7uW3tllErElBAAqKyvFbgIB5cATyoIflAUfeMmBCjiRTfF0FrsJpJempiaxm0BAOfCEsuAHZcEHXnKgAk5k/lYFHN0HJzalUil2EwgoB55QFvygLPjASw5UwInMVd1rR6D6TXSxsdQTmAeUAz8oC35QFnzgJQcq4EQmlUqE3x0cZCK2hABAWlqa2E0goBx4Qlnwg7LgAy85UAHHAXlPETd7upfILSGEEELIREAFHAdYz7XT/LIGkVtCfH19xW4CAeXAE8qCH5QFH3jJgQo4DrhrVACAxUkhIreEqNVqsZtAQDnwhLLgB2XBB15yoAKOA1LWPf4bjQMnvpKSErGbQEA58ISy4AdlwQdecqACjgMNOiMA4Ou0MnEbQgghhJAJgQo4DrhrHAEAU7xoUF+xzZgxQ+wmEFAOPKEs+EFZ8IGXHKiA40CzTg8AqKhpFbklpLq6WuwmEFAOPKEs+EFZ8IGXHKiA48A0n+4zcIYuE3Z/ky9yaya3hgbqCcwDyoEflAU/KAs+8JIDFXAciAjovnTa2WXCK59li9yayc3BwUHsJhBQDjyhLPhBWfCBlxyogONAaHCg8HtcuLeILSFxcXFiN4GAcuAJZcEPyoIPvORABRwH6qpPC79nFtWK2BKSkpIidhMIKAeeUBb8oCz4wEsOVMBxQKU4FwOdgSOEEELIUKiA44CPp7vwO52BE5e3NxXQPKAc+EFZ8IOy4AMvOVABxwFPd43wO52BE5erq6vYTSCgHHhCWfCDsuADLzlQAceBs7VnhN/pDJy4Tpw4IXYTCCgHnlAW/KAs+MBLDlTAcYDugSOEEELIcFABx4GYqHDhdzoDJ67IyEixm0BAOfCEsuAHZcEHXnKgAo4DHbpm4ffZYV4itoTU19eL3QQCyoEnlAU/KAs+8JIDFXAcaGttEn6nM3Di4uXAnOwoB35QFvygLPjASw5UwHFA4SCDRNL9+4xg98FXJnYlldIhwQPKgR+UBT8oCz7wkoOEMcbEbgSvWlpa4OrqiubmZmg0mqFfMAoRK14HY4C31hE/vXirXT+LEEIIIXyytfbgo4yc5NLS0qBSyAEA0wPdRG7N5JaWliZ2EwgoB55QFvygLPjASw5UwHHAbDajy2gCABSUN4jcmsnNbDaL3QQCyoEnlAU/KAs+8JIDFXAc8PT0hIujAgAQ7MfHCM+Tlaenp9hNIKAceEJZ8IOy4AMvOVABxwFPT0+0dxoBACWnm8RtzCTHy4E52VEO/KAs+EFZ8IGXHCZFAbd06VK4ubnhpptuErsp/SooKIC7RgUA8PdUi9yaya2goEDsJhBQDjyhLPhBWfCBlxwmRQF
333334a233hK7GYOqb+4AAJScbh5iTUIIIYRMdpOigFuwYAFcXFzEbsaApk+fDmnPQHA0pou4pk+fLnYTCCgHnlAW/KAs+MBLDqIXcD/88AOuvfZa+Pv7QyKR4JNPPumzzvbt2xEcHAyVSoWkpCSkpqaOf0PtqLm5GRfP9O9+wBh2f5MvboMmseZmOgPKA8qBH5QFPygLPvCSg+gFXFtbG2bNmoXt27f3+/x7772HDRs24IknnkBmZiZmzZqFRYsWobb23JRTs2fPRkxMTJ+fqqqq8fozRqW2thbzZ04BABiMZrzyWbbILZq8eu9XRDyUAz8oC35QFnzgJQe52A1YvHgxFi9ePODzW7ZswR133IE1a9YAAF5++WV88cUX2LFjBx5++GEAQFZW1pi0pbOzE52dncLjlpaWMXlfW1iGEQGAuHDvcftcQgghhEw8ohdwgzEYDMjIyMAjjzwiLJNKpUhOTsbhw4fH/PM2b96MTZs29Vmenp4OtVqNuLg45Ofno6OjAy4uLggJCUF2dvfZsqlTp8JsNqOyshJA91nB4uJi6HQ6qNVqhIeH4+jRowCAgIAAyGQylJeXAwBiY2OR/1Wm8HmZhTVISUkBAPj7+0OlUuHkyZMAgJiYGJw6dQpNTU1QKBSYPXu2cEnZ19cXzs7OKC4uBgBERUWhpqYGDQ0NkMvliI+PR2pqKhhj8PLygpubG4qKigAAERERaGhoQF1dHaRSKRITE5Geng6TyQQPDw94e3sjP7/70u706dPR0tKCmpoaAEBSUhIyMzPR1dUFNzc3+Pv7Izc3FwAQGhqK9vZ2nDlzBgCQkJCAnJwc6PV6uLq6IigoCMePHwcABAcHw2g04tSpUwCAuLg4FBQUoL29Hc7OzggNDcWxY8cAAEFBQQCAiooKAMCsWbNQUlICnU4HJycnREZGIjMzU9jecrkcZWVlAICZM2eioqICzc3NUKlUiImJQXp6OgDAz88P9fX1KCkpAQBER0ejqqoKjY2NcHBwQFxcnJCNj48PNBoNTpw4IWzv2tpanD17FjKZDAkJCUhLS4PZbIaXlxfc3d1RWFgIAAgPD0djYyPq6uogkUgwd+5cZGRkwGg0wt3dHT4+PsL2DgsLg06nQ3V1NQBg7ty5yMrKgsFggFarRUBAAHJycgAA06ZNg16vF84+x8fHIzc3F3q9HhqNBsHBwVb7rMlkErb3nDlzUFRUhLa2Njg7OyMsLEz4z1FgYCCkUqnVPltaWorW1lY4OjoiKipK2N5TpkyBQqFAaWmpsL0rKyvR1NQEpVKJ2NhYYSRzX19fqNVqYXvPmDFD+DszMzOttre3tzdcXV2F7R0ZGYn6+nrU19cL+6xle3t6esLT01PoLTZ9+nQ0NzcL/3Puvc+6u7vD19cXeXl5wj7b1tYmtCMxMRHZ2dno7OyEVqtFYGCgsM+GhITAYDDg9OnTwj5rr++IsrIytLS0QKVSITo6GhkZGQDs/x0BACkpKfQd0es7wsnJSZTvCLlcjpSUFPqOqK5GQ0NDn+09nt8R5eXldvuOsLWXK1dzoUokEuzduxdLliwBAFRVVWHKlCn45ZdfMG/ePGG9hx56CN9//70Q2lCSk5Nx7NgxtLW1wd3dHR988IHV+1n0dwYuMDDQ7nOhZmZmokvlh5VPfQkAuGb+NGxZf7ndPo8MzFI0EHFRDvygLPhBWfDB3jnYOhcq12fgxso333xj03pKpRJKpdLOremrq6sLLu4OwuPMIj6ur09GXV1dYjeBgHLgCWXBD8qCD7zkwHUB5+npCZlMJpyGt6ipqYGvr69IrRp77u7uUNI9cFxwd3cXuwkElANPKAt+UBZ84CUH0XuhDkahUCA+Ph4HDx4UlpnNZhw8eLDfS6ATla+vL5x7FXA/HjslYmsmtwvpPwYTGeXAD8qCH5QFH3jJQfQCTqfTISsrS7gRsrS0FFlZWcKNpxs2bMCrr76KN998E/n5+bjrrrvQ1tYm9Eq9EOTl5cH
F6VwB12kwidiayc1ykyoRF+XAD8qCH5QFH3jJQfRLqOnp6bj88nM37G/YsAEAsGrVKuzcuRPLli1DXV0dHn/8cVRXV2P27NnYv38/fHx8xGqyXTjIz9XSjOZjIIQQQsggRC/gFixYgKE6wq5fvx7r168fpxaNv9DQUKvHBqNZpJaQ87Mg4qAc+EFZ8IOy4AMvOYh+CZV0z0bRG2PAgj/uEak1k9v5WRBxUA78oCz4QVnwgZccRD8Dx6Pt27dj+/btMJm670Wz90C+JpMJ7e3tmOKuwOkGAwCgqr4NKSkpNJDvOA/S2dHRYTVoJA3kmwVg/AfpLCkpwdmzZ2kgXw4G8s3Pz0d1dTV9R3AwkG9BQQGqq6vpO0LkgXwNhu5/p2kgX47ZOpjeaKWkpCApKQmMMUSs2CEsL9p1u90+k/TPkgURF+XAD8qCH5QFH+ydg621BxVwgxivAs5sNkMq7b6aHX7r68JyKuDGX+8siHgoB35QFvygLPhg7xxsrT1oT+CA5fTp+Tb857txbgkZKAsyvigHflAW/KAs+MBLDlTAcaD3/KtzpnsJv3/+y0kxmjOp9c6CiIdy4AdlwQ/Kgg+85EAFHAe0Wq3w+4sbfmP13O5v8se5NZNb7yyIeCgHflAW/KAs+MBLDlTAcSAwMFD43cPV0eq5v715ZLybM6n1zoKIh3LgB2XBD8qCD7zkQAUcByzdjS2cHR2E340mGtR3PJ2fBREH5cAPyoIflAUfeMmBCjgOZby20uoxXUYlhBBCSG9UwHEgJCTE6rFEIrF6vOmNw+PZnEnt/CyIOCgHflAW/KAs+MBLDjQTQz/GeyYGLy8vnD171mqU9UAPJSrPdvd0MTMmjDRNMzHYd5R1lUoFqVRKMzGIPMr66dOncerUKZqJgYOZGPLy8lBaWkrfERzMxFBYWIjS0lL6jhB5JgaNRoOOjg6aiYFn4z0Tw/l6D+p7zfxp2LL+cru1gXSjkc75QDnwg7LgB2XBB15mYqBLqBzrfSWVxoQjhBBCiAUVcByIi4vrd/mflieOc0vIQFmQ8UU58IOy4AdlwQdecqACjgOW+xjO9/trYq0e3/jXT8ejOZPaQFmQ8UU58IOy4AdlwQdecqACjgMdHR0DPufmrBR+P36yfjyaM6kNlgUZP5QDPygLflAWfOAlByrgOODi4jLgc99tWzaOLSGDZUHGD+XAD8qCH5QFH3jJgQo4Dgw2poyTysHq8Yb/fGfv5kxqvIzvM9lRDvygLPhBWfCBlxyogOOAZQyYgcik57qjUm9U+xoqCzI+KAd+UBb8oCz4wEsOVMBNAA//jsb9IYQQQsg5VMBxYOrUqYM+v+rKaKvHdBnVfobKgowPyoEflAU/KAs+8JIDTaXVj/GeSssyJcpg0+RIJYC5Z86Mz385iWWJTjSVFsZ+mhxHR0fI5XKaSkvkaXKqqqpQVVVFU2lxMJVWYWEhysvL6TuCg6m0Tpw4gfLycvqOEHkqLa1Wi87OTppKi2diT6XV25v7c/H0W0eEx0W7brdbeyYzmqqGD5QDPygLflAWfKCptMiwnH8ZlQb1JYQQQiYvKuA4MHv2bJvW690blQb1tQ9bsyD2RTnwg7LgB2XBB15yoAKOA5b7UYZCvVHtz9YsiH1RDvygLPhBWfCBlxyogOOATqezaT26jGp/tmZB7Ity4AdlwQ/Kgg+85EAFHAfUarXN67q50Nyo9jScLIj9UA78oCz4QVnwgZccqIDjQHh4uM3rHtp2ix1bQoaTBbEfyoEflAU/KAs+8JIDFXAcsIz9ZAtHpfXQfTSo79gaThbEfigHflAW/KAs+MBLDlTATUA0NyohhBAyuVEBx4GAgIBhrU+9Ue1nuFkQ+6Ac+EFZ8IOy4AMvOVABxwGZTDas9WluVPsZbhbEPigHflAW/KAs+MBLDlTAccAyd9xw0GVU+xhJFmTsUQ78oCz4QVnwgZccqICboOgyKiGEEDJ5yYdeZfLZvn07tm/fDpPJBAB
IT0+HWq1GXFwc8vPz0dHRARcXF4SEhCA7OxsAMHXqVJjNZlRWVgLonmqjuLgYOp0OarUa4eHhQs+VgIAAyGQyoYoPDw9Hfn4+WlpaoFKpEB0djYyMDACAv78/VCoVTp7sPssWExODU6dOIdLNeiDBRfe/i//+3yVwdnYWRomOiopCTU0NGhoaIJfLER8fj9TUVDDG4OXlBTc3NxQVFQEAIiIi0NDQgLq6OkilUiQmJiI9PR0mkwkeHh7w9vZGfn4+AGD69OloaWlBTU0NACApKQmZmZno6uqCm5sb/P39kZubCwAIDQ1Fe3s7zpw5AwBISEhATk4O9Ho9XF1dERQUhOPHjwMAgoODYTQacerUKQBAXFwcCgoK0N7eDmdnZ4SGhuLYsWMAgKCgIABARUUFAGDWrFkoKSmBTqeDk5MTIiMjkZmZKWxvuVyOsrIyAMDMmTNRUVGB5uZmqFQqxMTEID09HQDg6emJ+vp6lJSUAACio6NRVVWFxsZGODg4IC4uDikpKQAAHx8faDQanDhxQtjetbW1OHv2LGQyGRISEpCWlgaz2QwvLy+4u7ujsLBQyLyxsRF1dXWQSCSYO3cuMjIyYDQa4e7uDh8fH2F7h4WFQafTobq6GgAwd+5cZGVlwWAwQKvVIiAgADk5OQCAadOmQa/Xo6qqCgAQHx+P3Nxc6PV6aDQaBAcHW+2zJpNJ2N5z5sxBUVER2tra4OzsjLCwMGRlZQEAAgMDIZVKhX02NjYWpaWlaG1thaOjI6KiooTtPWXKFCgUCpSWlgrbu7KyEk1NTVAqlYiNjUVaWhoAwNfXF2q1WtjeM2bMQHV1NUwmEzIzM622t7e3N1xdXYXtHRkZifr6etTX1wv7rGV7e3p6wtPTEwUFBcI+29zcjNra2j77rLu7O3x9fZGXlyfss21tbcL2TkxMRHZ2Njo7O6HVahEYGCjssyEhITAYDDh9+rSwz9rjOyI2NhZlZWXD+o5oamqCQqHA7NmzkZqaKmzv4X5HmEwmpKSk0HdEz3eEn58fnJycRPmOkEgkSElJoe+I6mo0NDT02d7j9R0REBCA8vJyu31HWNo0FAljjNm05iTU0tICV1dXNDc3Q6PR2O1z8vPzERUVNezXRf1uB0zmc/EV7bp9LJs1KY00CzK2KAd+UBb8oCz4YO8cbK096BIqB1paWkb0OrqMOvZGmgUZW5QDPygLflAWfOAlByrgOKBSqUb0uvN7oybe8fZYNGdSG2kWZGxRDvygLPhBWfCBlxyogONAdHT00CsNwNnRQfi9uc2A3d/kj0WTJq3RZEHGDuXAD8qCH5QFH3jJgQo4DlhuRh6JzNdvs3r8xI5fRtucSW00WZCxQznwg7LgB2XBB15yoALuAqBxcrB6HLt6pzgNIYQQQsi4oAKOA/7+/qN6ffpr1mfh9AYT3Q83QqPNgowNyoEflAU/KAs+8JIDFXAcGIsbIq+ZP83qcXObgc7EjQAvN6dOdpQDPygLflAWfOAlByrgOGAZgHM0tqy/HP6eaqtleoMJM1buGPV7TyZjkQUZPcqBH5QFPygLPvCSAxVwF5BD226Bq1phtcxoYlTEEUIIIRcYmolhEOM1E0NbWxvUavXQK9powR/3oKq+zWqZXCZB3ttrx+wzLlRjnQUZGcqBH5QFPygLPtg7B5qJYQKxzDM3Vg5tuwUzp3laLTOaGN0TZ4OxzoKMDOXAD8qCH5QFH3jJgQo4DjQ1NY35e3701PV9ijjqnTo0e2RBho9y4AdlwQ/Kgg+85DCiAq6ystKqAk1NTcX999+P//73v2PWsMlEoVAMvdIIfPTU9X06NjS3GaiIG4S9siDDQznwg7LgB2XBB15yGNE9cJdeeinWrVuHlStXorq6GhEREYiOjsaJEydw77334vHHH7dHW8fN9u3bsX37dphMJhQVFeHgwYNQq9WIi4tDfn4+Ojo64OLigpCQEGRnZwMApk6dCrPZjMrKSgDA7NmzUVxcDJ1OB7V
ajfDwcBw9ehQAEBAQAJlMhvLycgDAzJkzUV5ejpaWFqhUKkRHRwsjPfv7+0OlUgm9XmJiYnDq1Ck0NTVBoVBg9uzZSE1NBQD4+vrC2dkZxcXFAICoqCjU1NTgtud/wdlWo9Xf6KVV4vOnr0ZRUREAICIiAg0NDairq4NUKkViYiLS09NhMpng4eEBb29v5Od3T9M1ffp0tLS0oKamBgCQlJSEzMxMdHV1wc3NDf7+/sjNzQUAhIaGor29HWfOnAEAJCQkICcnB3q9Hq6urggKCsLx48cBAMHBwTAajcJ/DuLi4lBQUID29nY4OzsjNDQUx44dAwAEBQUBACoqKgAAs2bNQklJCXQ6HZycnBAZGYnMzExhe8vlcpSVlQnbu6KiAs3NzVCpVIiJiUF6ejoAwM/PD05OTigpKQHQPWVKVVUVGhsb4eDggLi4OKSkpAAAfHx8oNFocOLECWF719bW4uzZs5DJZEhISEBaWhrMZjO8vLzg7u6OwsJCAEB4eDgaGxtRV1cHiUSCuXPnIiMjA0ajEe7u7vDx8RG2d1hYGHQ6HaqrqwEAc+fORVZWFgwGA7RaLQICApCTkwMAmDZtGvR6PaqqqgAA8fHxyM3NhV6vh0ajQXBwsNU+azKZhO09Z84cFBUVoa2tDc7OzggLC0NWVhYAIDAwEFKpVNhnY2NjUVpaitbWVjg6OiIqKkrY3lOmTIFCoUBpaamwvSsrK9HU1ASlUonY2FikpaUJ+6xarRa294wZM1BdXY2zZ89CoVBYbW9vb2+4uroK2zsyMhL19fWor68X9lnL9vb09ISnpycKCgqEfba5uRm1tbV99ll3d3f4+voiLy9P2Gfb2tqE7Z2YmIjs7Gx0dnZCq9UiMDBQ2GdDQkJgMBhw+vRpYZ+1x3dEbGwsysrK7PYd0dDQALlcjvj4eKSmpoIxBi8vL7i5uaGwsBASiYS+Izj4jrBsb/qOqEZDQ0Of7X2hfEcUFBRg4cKFQ94DN6ICzs3NDUeOHEFERAS2bduG9957Dz///DO+/vpr3Hnnndx0sR2t8erEkJKSgqSkJLu9P9A90X1zm8Fqmb+nGoe23WLXz51oxiMLMjTKgR+UBT8oCz7YOwe7dmLo6uqCUqkEAHzzzTe47rrrAHRXu5b/RRG+pL26ss8QI1X1bbjxr5+K1CJCCCGEjNSICrjo6Gi8/PLL+PHHH3HgwAFceeWVAICqqip4eHiMaQMnA19f33H5nLRXV0KlkFktO36yHgv+uGdcPn8iGK8syOAoB35QFvygLPjASw4jKuCeeeYZvPLKK1iwYAGWL1+OWbNmAQA+++wzzJ07d0wbOBk4OzuP22dl71zdp4irqm9DAnVsADC+WZCBUQ78oCz4QVnwgZccRlTALViwQLgxcMeOc6P8r1u3Di+//PKYNW6ysNxQPF6yd66GXCaxWtbSZkD4ra+Pazt4NN5ZkP5RDvygLPhBWfCBlxxGVMB1dHSgs7MTbm5uAIDy8nJs3boVhYWF8Pb2HtMGEvvIe3ttnyIOAMJvfR0zfkdTbxFCCCE8G1Ev1CuuuAI33HAD7rzzTjQ1NSEyMhIODg6or6/Hli1bcNddd9mjreNuvHqhtrS02PX9B9Nf71QLuVSCvHcm1/RbYmZBzqEc+EFZ8IOy4IO9c7BrL9TMzExceumlAIAPP/wQPj4+KC8vx1tvvYVt27aNrMWTmGWsJDGkvboSRbtu7/c5o5kh/NbXEbHidUyWKXPFzIKcQznwg7LgB2XBB15yGFEB197eDhcXFwDA119/jRtuuAFSqRQXXXSRMIgfsV1DQ4PYTUDRrtv7TL1lwRgQsWIHwm99HUse2TvOLRtfPGRBKAeeUBb8oCz4wEsOIyrgwsLC8Mknn6CyshJfffUVrrjiCgBAbW0tnd4dAblcLnYTAHRPvTVYIQcAeeUNCL/1dYTf+jqufujjC+7MHC9ZTHaUAz8oC35QFnzgJYcR3QP34Ycf4tZbb4XJZMKvf/1
rHDhwAACwefNm/PDDD9i3b9+YN1QM43UPHK82/Oc7fP6L7bNqZLy2Ei5OfMwRRwghhExEttYeIyrgAKC6uhpnzpzBrFmzIJV2n8hLTU2FRqNBZGTkyFrNmfEq4FJTU7kfP2/Gyh0wmoa3qxx+6VZ4uDraqUX2MRGymAwoB35QFvygLPhg7xxsrT1GfB7Q19cXvr6+wiS3AQEBtGON0ES4DJn39rneqLYWc/Pu2mX1+P9+G4e7ls4Z87aNpYmQxWRAOfCDsuAHZcEHXnIYUQFnNpvx1FNP4Z///Cd0Oh0AwMXFBQ888AAeffRR4YwcsY2Xl5fYTRiW3sXc7m/y8cSOX2x63b8+yMS/PsgUHk8PcMPnzyyFRNJ3PDqxTLQsLlSUAz8oC35QFnzgJYcRFXCPPvooXn/9dfzjH//AxRdfDAD46aefsHHjRuj1ejz99NNj2sgLnWVA5IloeXIUlidHCY+HU9CdONWIiBXnBg32cnPCT/+5RdSCbiJncSGhHPhBWfCDsuADLzmM6B44f39/vPzyy7juuuusln/66ae4++67cfr06TFroJjG6x64lJQUJCUl2e39xRZx6+sYyQnnRXOD8cL9C8e8PYO50LOYKCgHflAW/KAs+GDvHOx6D1xDQ0O/HRUiIyO5GR+F8KPwvIGCbb2H7qvUMmF+VoVcimM7V0FGl+cJIYSQkZ2BS0pKQlJSUp9ZF+69916kpqYiJSVlzBoopvE6A9fU1AStVmu39+dd7Oqd0BtMw3rND/+5Bb7u6jFvy2TPgheUAz8oC35QFnywdw52HUbk+++/x9VXX42goCDMmzcPAHD48GFUVlbiyy+/FKbZmqi2b9+O7du3w2QyoaioCAcPHoRarUZcXBzy8/PR0dEBFxcXhISEIDs7GwAwdepUmM1mVFZWAgBmz56N4uJi6HQ6qNVqhIeH4+jRowC6e+zKZDJh1gqtVguz2YyWlhaoVCpER0cjIyMDQPflapVKhZMnu8dji4mJwalTp9DU1ASFQoHZs2cjNTUVQHfPYGdnZxQXFwMAoqKiUFNTg4aGBsjlcsTHxyM1NRWMMXh5ecHNzQ1FRUUAgIiICDQ0NKCurg5SqRSJiYlIT0+HyWSCh4cHvL29kZ+fDwCYPn06WlpahOlEkpKSkJmZia6uLri5ucHf3x+5ubkAgNDQULS3t+PMmTMAgISEBOTk5ECv18PV1RVBQUE4fvw4ACA4OBhGoxG/fmg/hrtXfvTk1WAd9ehob4OTkxMiIyORmZkpbG+5XI6ysjIAwMyZM1FRUYHm5maoVCrExMQgPT0dQPcAjVOnTkVJSQkAIDo6GlVVVWhsbISDgwPi4uKE/6D4+PhAo9HgxIkTwvaura3F2bNnIZPJkJCQgLS0NJjNZnh5ecHd3R2FhYUAgPDwcDQ2NqKurg4SiQRz585FRkYGjEYj3N3d4ePjI2zvsLAw6HQ6VFdXAwDmzp2LrKwsGAwGaLVaBAQEICcnBwAwbdo06PV6VFVVAQDi4+ORm5sLvV4PjUaD4OBgq33WZDIJPcnnzJmDoqIitLW1wdnZGWFhYcjKygIABAYGQiqVCvtsbGwsSktL0draCkdHR0RFRQnbe8qUKVAoFCgtLRW2d2VlJZqamqBUKhEbG4u0tDRhn1Wr1cL2njFjBqqrq1FZWQmtVmu1vb29veHq6ips78jISNTX16O+vl7YZy3b29PTE56enigoKBD22ebmZtTW1vbZZ93d3eHr64u8vDxhn21raxO2d2JiIrKzs9HZ2QmtVovAwEBhnw0JCYHBYBBuG7HXd0RsbCzKyspE+Y44evQo1Go1V98Rln02Li4OBQUFaG9vh7OzM0JDQ3Hs2DEAQFBQEACgoqICADBr1iyUlJRAp9ON6jvCz88PTk5OonxHfP/991CpVPQdUV2NhoaGPtt7vL4jnJycoNFo7PYdUVBQgIULF9pvHLiqqips375d+OOjoqKwbt06PPXUU/jvf/87krfkDt0
DJ74Ff9yDqvq2Eb323ceuwlRfV3hpHW3uGEFZ8IFy4AdlwQ/Kgg8T+h44oPt/fef3Nj127Bhef/31C6aAGy807MrADm27Rfh9OD1cAWDF374c8LlX/vQbTPF0hp+HM5wdHYQCj7LgA+XAD8qCH5QFH3jJYcRn4Ppz7NgxxMXFwWQa3v1MvJrsU2nxbrgF3XD9697L4enqCE9XR3hpHeHipOBqzDpCCCEXHrtPpdUfKuBGJj09HQkJCXZ7/8nG0nN1vGjVSjy/fgH8PNTwdVdbndEjI0PHBD8oC35QFnywdw52v4RKxs6FUvDyoui8YUss7HXGrqmtE79/5qtB15HLJNj6x1/DW+sET60jvFwdoVTQ4TcQOib4QVnwg7LgAy85DOtfkBtuuGHQ55uamkbTlknLw8ND7CZMCufPGtGff+/+Adv/d2LMP9toYlj/r4NDrieXSvCPOy+Dl5sTvFwd4aV1gkY9+S7d0jHBD8qCH5QFH3jJYViXUNesWWPTem+88caIG8ST8bqE2tLSQvfYcWI4WYxk/Lqx5O+pxhOr53ef0dM6wUPjCAc5HzfXjhYdE/ygLPhBWfDB3jmIcg/chYaGEZl8xjqL0QyDMpZuWhCO6y4OhYerIzw0Krg6K7me1YKOCX5QFvygLPgw4YcRIYQMrfcwKIOxd6H34aEifHioyKZ1tWoFNt95GTw0jnDXqODh6gi1ysFubSOEEDJ8dAZuEON1Bq6hoQHu7u52e39iu4mQReSK12Hm5KiVSoBt9y+El6sjPFwd4aiUQyGXjfq+vYmQw2RBWfCDsuCDvXOgM3ATSEtLCx2UnJgIWRS8238v2/Pd+NdPcfxkvV3bYmawqXMG0F3sOSrlaNMb4eaixIEtvx1wbL2JkMNkQVnwg7LgAy85UAHHgZqaGgQHB4vdDIILK4uPnrp+WOvPWLkDRpP9Tu2ZGdCmNwIAGls7kXDHO4Ou76pW4MnbLwEAuLkowQDUNLRhipcL3JyVmOavhVQ6uXrniuFCOiYmOsqCD7zkQAUcIQQAkPf2WpvXHY+ze81tBty37dsxez+pBPjLbRfBWeUAhYMMXUYzDEYTzGYGpYMM0/y1CPR2AQPg4ugAmUwKuaz/jh56gxEqGsePECIiugduEDSVFiGjx0tP3PHipXVEbKgXvs+qhNHE4CCX4vpLQvGr2UGQSSWQySQ4VauDn4ca0wPcIJdJoHCQwUnlAJPJjFN1rQjy1sDZSQHGGFrbDdColWL/WQCA1nYDHJXyAQtbQsjo0TAiY2C8CrjMzEzExcXZ7f2J7SgL8Yk9vt5EpXVW4tWHFiEq2B0Kucym1+gNRqTkncFFM/wGnBnEckw0tHRg/l27YGbANfOnYcv6y8ey+cQG9P3EB3vnQJ0YJpCuri6xm0B6UBbiy965eszGWRqPS728aNJ14rePf9bvc39ZmYTVi2P6LF/51Jc4VlyHxEgfvPv4Nf2+1nJMHMk7I/R+3neklAo4EdD3Ex94yYEKOA64ubmJ3QTSg7Lgw1jlMNyOHCPFe6H497dT8Pe3U4THIb4afLXltzhWXAcASCuowe5v8vudas6ShaxXh5HFF4XYucWkP/T9xAdecqBLqIMYr0uoOp0Ozs7Odnt/YjvKgg+Ug20S73gbzW2GMXkvlUKG7J2r+yy3ZLE/pRR//Hd3p5KCd9ZSD2AR0HHBB3vnYGvtQXeiciA3N1fsJpAelAUfKAfbpL26EkW7bu/3Z+Y0z2G9l95gwob/fNdnuSWL3v/VN/EykvQkQ8cFH3jJgS6hEkLIBej8y8e2jPP3+S8n8c97FmDFk1+gsLIRD96SgGkufdczm82g//8TIi46AjkQGhoqdhNID8qCD5TD2Mt7e61wdk6lGLiXatztbyG9sAat7QZseS9DyILhXPG3ZvN+u7eX9EXHBR94yYEKOA60t7eL3QTSg7LgA+VgX9k7V6No1+1wVSv6PGeZLaP7d0O/WWQU1ti1faR/dFzwgZccqID
jwJkzZ8RuAulBWfCBchgflnvo/D3U/T5vNDG8+r9sANb3wDEAu7/JH4cWkt7ouOADLzlQAUcIIZPcoRdu6fdsHAC893MNVjz5BczndVx45bPs8WgaIWQA1ImhH9u3b8f27dthMnWPBp+eng61Wo24uDjk5+ejo6MDLi4uCAkJQXZ295fY1KlTYTabUVlZCQCYPXs2iouLodPpoFarER4ejqNHjwIAAgICIJPJUF5eDgCIjo5Gfn4+WlpaoFKpEB0djYyMDACAv78/VCoVTp48CQCIiYnBqVOn0NTUBIVCgdmzZyM1NRUA4OvrC2dnZxQXFwMAoqKiUFNTg4aGBsjlcsTHxyM1NRWMMXh5ecHNzQ1FRUUAgIiICDQ0NKCurg5SqRSJiYlIT0+HyWSCh4cHvL29kZ/f/T/u6dOno6WlBTU13ZdRkpKSkJmZia6uLri5ucHf31/opRMaGor29nbhfywJCQnIycmBXq+Hq6srgoKCcPz4cQBAcHAwjEYjTp06BQCIi4tDQUEB2tvb4ezsjNDQUBw7dgwAEBQUBACoqKgAAMyaNQslJSXQ6XRwcnJCZGQkMjMzhe0tl8tRVlYGAJg5cyYqKirQ3NwMlUqFmJgYpKenC9uwvr4eJSUlQjZVVVVobGyEg4MD4uLikJLSPZ6Wj48PNBoNTpw4IWzv2tpanD17FjKZDAkJCUhLS4PZbIaXlxfc3d1RWFgIAAgPD0djYyPq6uogkUgwd+5cZGRkwGg0wt3dHT4+PsL2DgsLg06nQ3V1NQBg7ty5yMrKgsFggFarRUBAAHJycgAA06ZNg16vR1VVFQAgPj4eubm50Ov10Gg0CA4OttpnTSaTsL3nzJmDoqIitLW1wdnZGWFhYcjKygIABAYGQiqVCvtsbGwsSktL0draCkdHR0RFRQnbe8qUKVAoFCgtLRW2d2VlJZqamqBUKhEbG4u0tDRhe6vVamF7z5gxA9XV1WCMCaOdW7a3t7c3XF1dhe0dGRmJ+vp61NfXC/usZXt7enrC09MTBQUFwj7b3NyM2traPvusu7s7fH19kZeXJ+yzbW1twvZOTExEdnY2Ojs7odVqERgYKOyzISEhMBgMOH36tLDP2uM7IjY2FmVlZXb9jnj5zhm48+W8foclSSuohrfWusD7w3Wxk/I7ws/PD05OTqJ8R8hkMqSkpNB3RHU1Ghoa+mzv8fqOCAkJQXl5ud2+IyxtGgqNAzeI8RoH7tixY5g1a5bd3p/YjrLgA+UgrsgVO2A+758GB7kUXUaz8HjT2vn9DvxL7IeOCz7YOwcaB24C0ev1YjeB9KAs+EA5iOux1Rf1Wda7eAPoEqoY6LjgAy85UAHHAVdXV7GbQHpQFnygHMS14jczEBMy+EDAbi7KcWoNsaDjgg+85EAFHAcs92oQ8VEWfKAcxPfx09fjisSpAz6fX3Z2HFtDADoueMFLDlTAccBysyMRH2XBB8qBD//5v2QMNOWpo5L6wI03Oi74wEsOVMARQggZ0G0L/Ppd3qY30lhwhIiICjgOBAcHi90E0oOy4APlwI/br4tHxmsr+31u8zvdw5MwxvD5LyWoqGkZz6ZNOnRc8IGXHOgcOAeMRuPQK5FxQVnwgXLgh9FohIuTAlIJcN5YvtAbus/CvfDxUdQ3dQAArpk/DVvWX2613kMvfo99KaV44JYErF4cM15Nv+DQccEHXnKgM3AcsAySSMRHWfCBcuCHJYuCd2/HxjXz+zz/97ePCMUbAOw7UtpnnU9+KkZnlwn/3JNmv4ZOAnRc8IGXHKiAI4QQYpNbfxOFL5+70WpZZ5f1+HAzgj0GfP35Z/AIISNHBRwH4uLixG4C6UFZ8IFy4Mf5WYRN0UIyQM9UACiqbIDRZEbx6SacP9FPVJC7PZo4adBxwQdecqACjgO2zntG7I+y4APlwI/+sujvUqpFZ5cZF9+9C1f96SPc/MT/rJ4L9udjANSJio4LPvCSAxVwHGhvbxe7CaQHZcEHyoE
f/WWxPDkK0YNcKm1s7QQAHCuus1pedoZ6qY4GHRd84CUHKuA44OzsLHYTSA/Kgg+UAz8GymLv35dAPtAovwPILqmjseNGgY4LPvCSAxVwHAgNDRW7CaQHZcEHyoEfg2WR987aQYu4/p555bPsMWjV5ETHBR94yYEKOA4cO3ZM7CaQHpQFHygHfgyVRd47a+E0wLRaDMD1j+y1WvaH62LHqmmTDh0XfOAlByrgCCGEjMqfV8zt92wbAOSXN4xrWwiZLKiA40BQUJDYTSA9KAs+UA78sCWL5clR2Lh2/oAT3/e25b2MMWjV5ETHBR94yYEKOEIIIaO2PDkKBe/ePuR6bR2GcWgNIRc+KuA4UFFRIXYTSA/Kgg+UAz+Gm8XMaZ6DPm80s357oj788g+YtfpNvL0/d1ifN5nQccEHXnKgAo4QQsiY+eip61G063Yo5AP/8/K3nYcRfuvrmLFyh1DMffzDCXQYjHiO5kslxCZUwHFg1qxZYjeB9KAs+EA58GOkWTx620UDPmfsmRTVaGLY/E6K9XNGc38vIaDjghe85EAFHAdKSkrEbgLpQVnwgXLgx0izWJ4chU1r58PFyWHQ9fQGE27Z+Pm5BcMbG3hSoeOCD7zkQAUcB3Q6ndhNID0oCz5QDvwYTRbLk6OQ8dptuGb+tEHXyyyqOfeAAW36LtQ0to34cy9UdFzwgZccqIDjgJOTk9hNID0oCz5QDvwYiyy2rL8cOW+utm1lCXDFhg9w6T17cMezX1k99dBL3yP+9rew68DknI6Ljgs+8JIDFXAciIyMFLsJpAdlwQfKgR9jlYXCQQZ/D/WQ6xlNDHVNHQCA77NOWT33yY/FaO3owvOTtKMDHRd84CUHKuA4kJmZKXYTSA/Kgg+UAz/GMotDL9wCjVoxotcaTec6NzDGxqpJEwodF3zgJQcq4AghhIyb9FdXDnlPnEXvmR3aOrqE373dhj6TR8iFjgo4DgQEBIjdBNKDsuAD5cAPe2SxZf3l2LR2/pDrSSXnKrjWXjM4lJ5p7ncwYADoNBhxOLcKBqNp9A3lDB0XfOAlByrgOCCXy8VuAulBWfCBcuCHvbJYnhw15Jm43hdKdb3OwAEDz6m64skvsOrpfbjjma9H20Tu0HHBB15yoAKOA2VlZWI3gfSgLPhAOfDDnllsWX85Nq4Z+Eycycyw4T/f4eo/f4zrHt5r9Vybvv85VbNP1gMADudWDXiWbqKi44IPvORABRwhhBDR3PqbqEHnT/38l5M4UdnYZ7nRNHRHhlc+yx5V2wjhmYRN1u48NmhpaYGrqyuam5uh0Wjs9jnt7e3cjCsz2VEWfKAc+DGeWSz9yyfILTtr8/p+HmqcOdsGhVyKR2+7CMuToxB+6+vC85vWzsfy5Ch7NHVMGE1mLP3LJ6iq1+HB5YlDtpWOCz7YOwdbaw8+LuRyZvv27di+fTtMpu6bYNPT06FWqxEXF4f8/Hx0dHTAxcUFISEhyM7u/h/e1KlTYTabUVlZCQCYPXs2iouLodPpoFarER4ejqNHjwLovgFSJpOhvLwcQPeggHK5HC0tLVCpVIiOjkZGRvf9Hf7+/lCpVDh58iQAICYmBqdOnUJTUxMUCgVmz56N1NRUAICvry+cnZ1RXFwMAIiKikJNTQ0aGhogl8sRHx+P1NRUMMbg5eUFNzc3FBUVAQAiIiLQ0NCAuro6SKVSJCYmIj09HSaTCR4eHvD29kZ+fvfliOnTp6OlpQU1Nd2jpyclJSEzMxNdXV1wc3ODv78/cnNzAQChoaFob2/HmTNnAAAJCQnIycmBXq+Hq6srgoKCcPz4cQBAcHAwjEYjTp3qHvspLi4OBQUFaG9vh7OzM0JDQ3Hs2DEAQFBQEACgoqICQPfcdCUlJdDpdHByckJkZKTQ1TsgIAByuVw47T1z5kxUVFSgubkZKpUKMTExSE9PF/IPDQ0VpkqJjo5GVVUVGhsb4eDggLi4OKSkdM/d6OPjA41GgxMnTgjbu7a
2FmfPnoVMJkNCQgLS0tJgNpvh5eUFd3d3FBYWAgDCw8PR2NiIuro6SCQSzJ07FxkZGTAajXB3d4ePj4+wvcPCwqDT6VBdXQ0AmDt3LrKysmAwGKDVahEQEICcnBwAwLRp06DX61FVVQUAiI+PR25uLvR6PTQaDYKDg632WZPJJGzvOXPmoKioCG1tbXB2dkZYWBiysrIAAIGBgZBKpcI+Gxsbi9LSUrS2tsLR0RFRUVHC9p4yZQoUCgVKS0uF7V1ZWYmmpiYolUrExsYiLS1N2GfVarWwvWfMmIHq6mqUl5fD3d3dant7e3vD1dVV2N6RkZGor69HfX29sM9atrenpyc8PT1RUFAg7LPNzc2ora3ts8+6u7vD19cXeXl5Qv5tbW3C9k5MTER2djY6Ozuh1WoRGBgo7LMhISEwGAw4ffq0sM/a4zsiNjYWZWVlonxHZGRkwMXFZVy+Ix6+3g93vNQIvcG2+VDPnO2ercFgNOO5XSlYekmI1fMLorXC/jNW3xF+fn5wcnIak++IFok3CnvOLm7/OAOLE/wG/Y44cuQI1Go1fUdUV6OhoaHP9h6v7wgHBwd4eHjY7TvC0qah0Bm4QYzXGbiUlBQkJSXZ7f2J7SgLPlAO/BAri5mr3kBnl+0T26sUcmTvXNXvGbhjxbU4c7YNsaFe8Pd0HvR9uoxmZBbVYFaYF1QK+57j+OLwSfzfC99ZtXUwdFzwwd452Fp70D1wHFCpVGI3gfSgLPhAOfBDrCyOv7lm0HvjzifpZ5nlHrjlGz/HH//9LRb88T0se+KzQd/nxb1HsfKpLzF33Tt27wTR+/yJLZd66bjgAy85UAHHgZiYGLGbQHpQFnygHPghZhYfPXW9UMTJpf2VaOd0GIxY8hfrnqrrro2F2cxgNJ8rlI6eqLNa52hRDfYdKRWKqZ37ui/t6g2mAYcqsQdbikU6LvjASw5UwHGg9/1XRFyUBR8oB36IncVHT12Pol23I++dtXj0tosGXTevrMHq8dNvH8Fb+3Otlp1fCC5/8gvct+1bXPnghwAA8zjeVdT7o2zpMSt2FqQbLzlQAUcIIWRCWHVlNDatnQ+5bPCzcRZdRjP++Z71xPe9z8YZTWaYex6XnmkBAOgNRuH5DcviR9vkQfW+hPqH62Lt+lnkwkMFHAf8/PzEbgLpQVnwgXLgB29ZLE+OQt7ba+GqVti0/vkdIVQKmfB7s65T+N0ya5ejcvwGZxju2T7espiseMmBCjgO0Lg+/KAs+EA58IPXLNJeXYmiXbdj09r5GOL2OCt6g0m436yxVS8st/Q47f1W4zkQsC2fxWsWkw0vOVABxwHL+DZEfJQFHygHfvCexfLkKBS8e/uwijhLsdS7gNM6KwEAHZ3nLqHa+7Jm7xNw664d+rN4z2Ky4CUHGsiXEELIhFfw7u0Auntz/u3Nw4NOtaVr77502tjrEqplUGAnlQN0HV0AbBvaYzR63wO3bGGEXT+LXHjoDBwHoqOjxW4C6UFZ8IFy4MdEy8Jyj9w186cNuE5LexfCb30d6/91UFhmOYPnpHIQlm153769DXufgdtlwzAiEy2LCxUvOVABxwHLlCZEfJQFHygHfkzULLasvxxFu263ubODmQEL7t2Djs4uYdnLnxzDrgP5MJsZfj5+GvXNHWPaxt7nCF/73/Eh15+oWVxoeMmBLqFyoLGxUewmkB6UBR8oB35M9CzSXl0JoPvS6sYdv2Cwfp9VPZdRe9v4xi/Y/M4RoTfrNfOnYcv6y8ekbb0vofYevmQgEz2LCwUvOdAZOA44ODgMvRIZF5QFHygHflwoWSxPjkLhruF1drDoPRTJ57+cBAD86/0MzF77Jv772bERt6l3MdnY2jngehYXShYTHS85UAHHgbi4OLGbQHpQFnygHPhxoWVR8O7tKNp1O/w91CN+j93f5OOlT7LQrjdi6wcjn26LDXMcuAsti4mKlxyogONASkqK2E0gPSgLPlAO/LhQszj0wi0o2tV
dzA1X7zHbjCY24knvhztr14WaxUTDSw5UwBFCCJnUNq2dj+FcWT17XmeGoSa9b9d3Wc36YDHcM3CE9EYFHAd8fHzEbgLpQVnwgXLgx2TIwnJ/3Mxpnjat39llsn7cqwPCume/xpy1bwln5bZ9mInZa99C4rp3sOML656m50+lNdSZvMmQxUTASw5UwHFAo9GI3QTSg7LgA+XAj8mUxUdPXS9Mz6V0sP2fxy6TGVds+ADht76OQ1mVaNN34Z97useQ2/7xUWG9Fz7KtHqd6bzBhoeaTmsyZcEzXnKgAo4DJ06cELsJpAdlwQfKgR+TMYvlyVE4/uYaXD0vxKb1TWaGsuoWq2Ut7Qbs/ibfqqep3mCC2XxuidFktnrNUFN32ZLF+n8dRNIf3hnxfXlkaLwcE1TAEUIIIf34172/tnkg4P5sfifV6rHJzPDoqz8Kj3sXc8DwOzWcr6CiAV+nlaGxtXPIs3lk4qMCjgNRUfadb4/YjrLgA+XAj8meRdqrK4XLqvJhDiLX2dV3cN7//XxuInTjeQXcUGPKDZVFS9u5jhLrrh38bB4ZOV6OCSrgOFBbWyt2E0gPyoIPlAM/KItuy5OjkPfO2mGdkbt6Xt/5WHuXbKbzLqHOCvMe9P2GysJBdu6f9Gd2pdJlVDvh5ZigAo4DZ8+eFbsJpAdlwQfKgR+UhTXLGblr5vctzs7X3yVRZ8dzo/ibzjsD9+OxU4O+31BZSHudIezoNNJlVDvh5ZigAo4DMplM7CaQHpQFHygHflAW/duy/nJsWjt/0Km5vjh8ss+y3lNmmczWZ+B0+q5BP3OoLM4vGIfqFEFGhpdjggo4DiQkJIjdBNKDsuAD5cAPymJgy5OjUPDu7VApRvYPutHU9xTdYJc9h8ri/HHllifzca/WhYaXY4IKOA6kpaWJ3QTSg7LgA+XAD8piaNk7VwvTcs0Idh9y/Rv/+imAvr1QgcFndRgqi/PvqaN74OyDl2OCCjgOmM87jU7EQ1nwgXLgB2UxPJ/8fSk2rZ0/6DrHT9YDAHJL6/s819E58GXUobLoGubAwOf71/sZSPj923j7q9xhvW6y4eWYoAKOA15eXmI3gfSgLPhAOfCDshi+5clRKNp1+6A9VsNvfR3ZJXV9lhuM3cWByWzGV6llqG1sF54bKottH1qfvRvuPXAvfZKFlnYD/v1B5tArT2K8HBNysRtAAHf3oU+5k/FBWfCBcuAHZTFyaa+uBNB9KfPJnYf79Dpt7+w7TpzFhhcOYV9KKQDgmvnTsGX95VZZ/O3Nw/jfzyX4v5vjhXvdMoush7cY+T1woxxR+ALHyzFBZ+A4UFhYKHYTSA/Kgg+UAz8oi9FbnhyFfBvHkHOQSVF2plko3oBzvVl7Z/H2V3lo0nXi+d1jfz9WS/vgvWEnO16OCSrgCCGEkHGQ9urKIe+P6zKZccUDH1otk/WMVZJ2ohlHcqugN5w7c9fZZRrzdg5zwgkiErqEyoHw8HCxm0B6UBZ8oBz4QVmMreXJUcKlzdjVO6E3DF2AGU0Mx4prse3LSuDLSkQGudm1jd5uTnZ9/4mOl2OCzsBxoLGxUewmkB6UBR8oB35QFvbTe/iRmBCPQdf97eP/E34vqDiXiaU/5PmDAo/G2WZ9n2WlZ5rRrOvsZ+3Jh5djggo4DtTV9e2JRMRBWfCBcuAHZTE+Pn56CYp2DX9QYHPP2G8d/XSIuH/btza/j1UBeN4l1J+Pn8aiBz5E4rp3sOLJL4bVvgsRL8cEFXAckEjohgNeUBZ8oBz4QVmML8tZOVvmWgUAS8fW3lN0WXx55FxHCLOZ4c7nD2Dene/2O8Bv7wIwOtj6bOD73527aT+toNrquTZ9F/7474O49J7dk2bgYF6OCSrgODB37lyxm0B6UBZ8oBz4QVmIY8v6y1G063b4e6qHXDf81tdx5YMf9vvctX/+GADwS85pfJtZgbMt+n5ne+hdwJ1fnzjIBi4Vnt+dhv0pZahpbB/2wMETFS/
HBBVwHMjIGHjqFDK+KAs+UA78oCzEdWjbLcJ9ctN8HAdcr8vY/z1whZWNuH/bt6g62yYsa9MbAACMMdy/7VvMu+tdvP/tubNsR0/UYdeBc2fTzp9jtbdfcqqE34c7cPBExcsxQQUcB4zGgQdzJOOLsuAD5cAPyoIfm24JHdF9cl8eKcVfX/1JeGw0MWz4z3fILKrFl0dKcbZZj5c/zbJ6zeZ3jgi/m3pN0dV7iJGCigaUnmkWHg80cPBXqWU4WdU0rDbzjJdjggo4DvAyqjOhLHhBOfCDsuCHJQvLfXJDjSk3mC8On8TpulbhsaHL+gxeZ6/HvWeQsPx637Zvcd3De61e0989cD8fP417tx7ElQ9+dMHcI8fLMUEFHAd8fHzEbgLpQVnwgXLgB2XBj/OzsMy5ainmhnNmjjHgzy//IDzWOiutnneQnysPevdQlfecgtvXq4OERX/31mWdqB30+bGw9h/7MWvNm3j3QJ5d3v98vBwTVMBxID//wvhfyYWAsuAD5cAPyoIfg2WxPDnKalw5W3qx9j6z1njeGG8xIZ4oqGiA2cys1pPLpH0utw6m991znV32ufT4U/ZpdHQa7TKtWH94OSaogCOEEEIuMFvWXz6qS6xHT9Tiuof3YuaqncgtrReW67tMA55J27Asvs8yc6/iT28wDXgZVddhwO3/2I+L79414kutg/S1uCBRAceBsLAwsZtAelAWfKAc+EFZ8GO4WVgusXZfXh3ZzJldJjPqmjps/jwAaNZ14nBuFRhjOL+m2vJeer+vfXFvFn7MPo26po4RX2q9en7oiF43XLwcE1TAcUCn04ndBNKDsuAD5cAPyoIfI82i+/LqqmENENybWmVb8WcZeuSGv36KVU/vw3WP7AUzW5dwnV39z/1aVHlueqqW9k6bz8SxXqfd5kX72dTO0eLlmKACjgPV1dVDr0TGBWXBB8qBH5QFP8YiC8sAwZvWzj9/xqwBteltu3ft2V2pAIDK2u7erYUVjX3GkFPI++9o0bvTBGNAXVMHXvns2JCf2fv9x+sSKi/HBBVwhBBCyCSzPDkKhT0dHmyZ7cEWBqP12TWFXNqngFs3jMF+pTZMWdX78u7Px0/b/N4XgpFdFCdjipdpOQhlwQvKgR+UBT/slcWhbbcIv2/4z3f4/JeTI3uj886AmRnrc1bMaOp/xoj+nKob+lLlxh2/CL9/m1lh83uPBi/HBJ2B40BWVpbYTSA9KAs+UA78oCz4MR5ZWC6xFu26HVLp8CZtZwDa9V3CY6OJIa/srNU6Wz/IxIb/fAcAqKrX4dOfimE0mXGq16DCQ/kuswIlp5sAAMdP1gnLp/pqhtXekeLlmKAzcBwwGAxiN4H0oCz4QDnwg7Lgx3hnUfDOWuz+Jh9P9DrLNRiTmWH22reslv3Sz2XNfUdKUXyqEQUV3R0X/vTi9/2+X6CXc59lmUU1+MPzBwAAm9bOtzrDV17dYlM7R4uXY4LOwHFAq9WK3QTSg7LgA+XAD8qCH2JkMdqhSPrrV+ColAvF22A8tY5Wj5t1nbhl4+fC41c+y7a6JBvk7TLs9o0EL8cEnYHjQEBAgNhNID0oCz5QDvygLPghZhbLk6OEcd52f5OPLe9loE3fNax72ix0HV1Dr4TuXqy97dyXY/XYzVmJqvpz98mdPNOM1nYDXJwUw27TcPByTNAZOA7k5OQMvRIZF5QFHygHflAW/OAli+XJUUh79XfIe3vNiO+Xs0V7p9FqLLgOg/VwJvnl1vfX6Tq6EP/7t0c8k4OteMmBCjhCCCGEjErBO2sxc5rnmL9v71kZnJQOVs/FRfQ/qXzv8ePONndgxZNf4JJ7dtu9sBtvF/wl1MrKSqxcuRK1tbWQy+V47LHH8Nvf/lbsZlmZNm34I2MT+6As+EA58IOy4AfvWXz01PVWj2f8bgeM5rEbXVepsB4EOL2gpt/1YkO9hd9f/CQLaQXdA+9ueS9DuAw8GrzkcMGfgZPL5di6dSvy8vL
w9ddf4/7770dbW5vYzbKi1+vFbgLpQVnwgXLgB2XBj4mWRV6vs3Ijv8DKEPW7HXjqzcPoNPQ/Ddf59qeU4t0DeQCAH4+dEpZ3dHZh1po38frn2SNuDcBPDhd8Aefn54fZs2cDAHx9feHp6YmGhgZxG3WeqqoqsZtAelAWfKAc+EFZ8GMiZvHRU9ejaNftKOzpyWrL7Aq9NbcZYDIzvPVVHjKL+j/j1p+/v53SZ5nBaEZHpxFb3s/o5xVAXtlZPPCf7xB/+1tCAdgfXnIQvYD74YcfcO2118Lf3x8SiQSffPJJn3W2b9+O4OBgqFQqJCUlITU1dUSflZGRAZPJhMDAwFG2mhBCCCHDsTw5CgXvrhWGJXFVK4bV+eGXHNsLpy5jd+/Y1va+Y7b113O2vrkDS/7yCf73y0m0dnTh3x9k2vxZYhH9Hri2tjbMmjULa9euxQ033NDn+ffeew8bNmzAyy+/jKSkJGzduhWLFi1CYWEhvL27r3PPnj0bRmPfyXa//vpr+Pv7AwAaGhpw22234dVXX7XvHzQC8fHxYjeB9KAs+EA58IOy4MeFlEXvYUlGNX3XEJp0nX2W9T4LuO3DTLzy2THIZdbns3zcnAZ8T15yEP0M3OLFi/HUU09h6dKl/T6/ZcsW3HHHHVizZg1mzJiBl19+GU5OTtixY4ewTlZWFnJycvr8WIq3zs5OLFmyBA8//DDmz58/YFs6OzvR0tJi9TMecnNzx+VzyNAoCz5QDvygLPhxoWbRe/qua+ZPg0wqgbOjw9AvHITGqfv1rs7KPs/1Puv3+hfH0dVzabW3E6eaBnxvXnIQ/QzcYAwGAzIyMvDII48Iy6RSKZKTk3H48GGb3oMxhtWrV+PXv/41Vq5cOei6mzdvxqZNm/osT09Ph1qtRlxcHPLz89HR0QEXFxeEhIQgO7v7ZsipU6fCbDajsrISQPdZweLiYuh0OqjVaoSHh+Po0aMAugcBlMlkKC8vBwCYTCbk5+ejpaUFKpUK0dHRyMjovkbv7+8PlUqFkye7/3cSExODU6dOoampCQqFArNnzxYuKfv6+sLZ2RnFxcUAgKioKNTU1KChoQFyuRzx8fFITU0FYwxeXl5wc3NDUVERACAiIgINDQ2oq6uDVCpFYmIi0tPTYTKZ4OHhAW9vb+Tnd3fBnj59OlpaWlBT030/QlJSEjIzM9HV1QU3Nzf4+/sLO3hoaCja29tx5swZAEBCQgJycnKg1+vh6uqKoKAgHD9+HAAQHBwMo9GIU6e6bzqNi4tDQUEB2tvb4ezsjNDQUBw71t09PCgoCABQUdE9efGsWbNQUlICnU4HJycnREZGIjMzU9jecrkcZWVlAICZM2eioqICzc3NUKlUiImJQXp6OgCgo6MD9fX1KCkpAQBER0ejqqoKjY2NcHBwQFxcHFJSuu+t8PHxgUajwYkTJ4TtXVtbi7Nnz0ImkyEhIQFpaWkwm83w8vKCu7s7CgsLAQDh4eFobGxEXV0dJBIJ5s6di4yMDBiNRri7u8PHx0fY3mFhYdDpdKiu7u5JNXfuXGRlZcFgMECr1SIgIEAYl2jatGnQ6/XCPRrx8fHIzc2FXq+HRqNBcHCw1T5rMpmE7T1nzhwUFRWhra0Nzs7OCAsLE+b8CwwMhFQqFfbZ2NhYlJaWorW1FY6OjoiKihK295QpU6BQKFBaWips78rKSjQ1NUGpVCI2NhZpaWnCPqtWq4XtPWPGDFRXV+PMmTMwmUxW29vb2xuurq7C9o6MjER9fT3q6+uFfdayvT09PeHp6YmCggJhn21ubkZtbW2ffdbd3R2+vr7Iy8sT9tm2tjZheycmJiI7OxudnZ3QarUIDAwU9tmQkBAYDAacPn1a2Gft8R0RGxuLsrIyUb4jzpw5A71eT98RPd8Rfn5+cHJyEuU7orq6Gnq9/oL+jrjtEjc8f/evkJaWhvt3FOBsa9+ra7Zoae/CvVu+hkre9xJtl9GMkpISeHp69incLMyMDfg
dYTAYUF5ebrfvCMv31lAkjLGx6+M7ShKJBHv37sWSJUsAdN8oOGXKFPzyyy+YN2+esN5DDz2E77//XjhIBvPTTz/hsssuQ2xsrLDs7bffxsyZM/us29nZic7Oc6dbW1paEBgYiObmZmg09pskNz8/H1FRo+/aTEaPsuAD5cAPyoIfkzWLBX/cg6r6kY0e4aSUo72fIi3Ayxnf/nsZwm99vd/X+XuoceiFW/p9zt45tLS0wNXVdcjag+szcGPhkksugdls21QfSqUSSmXf0632FhwcPO6fSfpHWfCBcuAHZcGPyZrFoW3dhZRlCq/OLiMMXSbYMsRcf8UbAJyq0/W73GLmNK8Bn+MlB9HvgRuMp6cnZDKZcBreoqamBr6+viK1auxZTp8S8VEWfKAc+EFZ8GOyZ2GZwit752oUvNt9v9wwRyWxciR34F6tX6WVDThzAy85cF3AKRQKxMfH4+DBg8Iys9mMgwcPWl1SJYQQQsjksmX95Sh8t3tIEn9PZ6Hjgq1ue3rfoM9vfmdkQ5aNF9Evoep0OuGGWgAoLS1FVlYW3N3dERQUhA0bNmDVqlVISEjA3LlzsXXrVrS1tWHNmjUitnpsTZ06VewmkB6UBR8oB35QFvygLPrXe0gSAJi56g10dtl269RgOruMeOrNw/jo+xPYsCweKxdFo6FFj06ZdtTvPRZEL+DS09Nx+eWXC483bNgAAFi1ahV27tyJZcuWoa6uDo8//jiqq6sxe/Zs7N+/Hz4+/U9iOxGZTLZND0Lsj7LgA+XAD8qCH5SFbY6/2X2CJ/GOt9Hc1ncgX1upVQ5466vunqfP7U7DykXRuPjuXTCZGS6JLcWOh68ck/aOlOiXUBcsWADGWJ+fnTt3CuusX78e5eXl6OzsREpKCpKSksRrsB1YumgT8VEWfKAc+EFZ8IOyGJ60V1cKY8tJJN3jvykdbC97dB1dwu+GnpkdTD09J34+fnpsGzsCop+BI4QQQgixly3rL8eW9eeu9A00dMhgzGZm1anBZZQDDY8FrsaB48X27duxfft2mEwmFBUV4eDBg3YdyDcyMhJVVVU0kC8Hg3R6e3vDxcWFBvIVeSDf+vp6KJVKGsiXg4F8CwoKIJVK6TuCg4F8LevSd0Q1Ghoa+mxvW78jDmTVY/eP1eg0miGTAv1MjdovqRSwjEomAfDWfTEA7DOQ78KFC4ccB44KuEHYOpjeaOXk5CAmJsZu709sR1nwgXLgB2XBD8rCPs6NL2eC0kFq031z4YFu+PyZvvO3jwUayHcCaWsb2QjTZOxRFnygHPhBWfCDsrCP83ux3v6P/fgxe/B73EpON9m5VUMTvRMDAZydncVuAulBWfCBcuAHZcEPymJ8vP7wldi0dv6g65hsmQbCzugS6iDG6xJqZ2enKFN4kb4oCz5QDvygLPhBWYhj9zf52PTGL1ZTd/l7qoUpvsaarbUHnYHjgOUmUCI+yoIPlAM/KAt+UBbiWJ4chYJeMz6svtzfbsXbcNA9cIQQQgghQ7DcK2fp9So2OgPHgcDAQLGbQHpQFnygHPhBWfCDsuADLzlQAccBqZRi4AVlwQfKgR+UBT8oCz7wkgMfrZjkLAMfEvFRFnygHPhBWfCDsuADLznQPXD96D0TAwCkp6fbdSYGk8mE/Px8momBg1HWOzo6UF9fTzMxiDzKemNjIzIzM2kmBg5mYmhsbERKSgp9R3AwE0NTUxNSUlLoO2KUMzGM9jvCYDCgvLzcbt8RljYNhYYRGcR4DSPS0dEBR0dHu70/sR1lwQfKgR+UBT8oCz7YOwcaRmQCsfwvhIiPsuAD5cAPyoIflAUfeMmBCjgOtLa2it0E0oOy4APlwA/Kgh+UBR94yYEKOA7QKXF+UBZ8oBz4QVnwg7LgAy850D1wgxive+C6urrg4OBgt/cntqMs+EA58IOy4AdlwQd750D3wE0glp45RHyUBR8oB35QFvygLPjASw40jMggLCcnW1pa7Po5bW1tdv8
MYhvKgg+UAz8oC35QFnywdw6W9x7qAikVcIOw3KjIy7QZhBBCCJkcWltb4erqOuDzdA/cIMxmM6qqquDi4gKJRGKXz2hpaUFgYCAqKyvtep8dGRplwQfKgR+UBT8oCz6MRw6MMbS2tsLf33/QabvoDNwgpFIpAgICxuWzNBoNHZScoCz4QDnwg7LgB2XBB3vnMNiZNwvqxEAIIYQQMsFQAUcIIYQQMsFQAScypVKJJ554AkqlUuymTHqUBR8oB35QFvygLPjAUw7UiYEQQgghZIKhM3CEEEIIIRMMFXCEEEIIIRMMFXCEEEIIIRMMFXCEEEIIIRMMFXAi2759O4KDg6FSqZCUlITU1FSxmzSh/fDDD7j22mvh7+8PiUSCTz75xOp5xhgef/xx+Pn5wdHREcnJyThx4oTVOg0NDVixYgU0Gg20Wi1uv/126HQ6q3Wys7Nx6aWXQqVSITAwEM8++6y9/7QJZfPmzUhMTISLiwu8vb2xZMkSFBYWWq2j1+txzz33wMPDA87OzrjxxhtRU1NjtU5FRQWuvvpqODk5wdvbG3/6059gNBqt1jl06BDi4uKgVCoRFhaGnTt32vvPm1BeeuklxMbGCgOPzps3D/v27ROepxzE8Y9//AMSiQT333+/sIyyGB8bN26ERCKx+omMjBSenzA5MCKaPXv2MIVCwXbs2MFyc3PZHXfcwbRaLaupqRG7aRPWl19+yR599FH28ccfMwBs7969Vs//4x//YK6uruyTTz5hx44dY9dddx0LCQlhHR0dwjpXXnklmzVrFjty5Aj78ccfWVhYGFu+fLnwfHNzM/Px8WErVqxgOTk5bPfu3czR0ZG98sor4/Vncm/RokXsjTfeYDk5OSwrK4tdddVVLCgoiOl0OmGdO++8kwUGBrKDBw+y9PR0dtFFF7H58+cLzxuNRhYTE8OSk5PZ0aNH2Zdffsk8PT3ZI488Iqxz8uRJ5uTkxDZs2MDy8vLYCy+8wGQyGdu/f/+4/r08++yzz9gXX3zBioqKWGFhIfvLX/7CHBwcWE5ODmOMchBDamoqCw4OZrGxsey+++4TllMW4+OJJ55g0dHR7MyZM8JPXV2d8PxEyYEKOBHNnTuX3XPPPcJjk8nE/P392ebNm0Vs1YXj/ALObDYzX19f9txzzwnLmpqamFKpZLt372aMMZaXl8cAsLS0NGGdffv2MYlEwk6fPs0YY+zFF19kbm5urLOzU1jnz3/+M4uIiLDzXzRx1dbWMgDs+++/Z4x1b3cHBwf2wQcfCOvk5+czAOzw4cOMse5iXCqVsurqamGdl156iWk0GmHbP/TQQyw6Otrqs5YtW8YWLVpk7z9pQnNzc2OvvfYa5SCC1tZWNn36dHbgwAH2q1/9SijgKIvx88QTT7BZs2b1+9xEyoEuoYrEYDAgIyMDycnJwjKpVIrk5GQcPnxYxJZduEpLS1FdXW21zV1dXZGUlCRs88OHD0Or1SIhIUFYJzk5GVKpFCkpKcI6l112GRQKhbDOokWLUFhYiMbGxnH6ayaW5uZmAIC7uzsAICMjA11dXVZZREZGIigoyCqLmTNnwsfHR1hn0aJFaGlpQW5urrBO7/ewrEPHUP9MJhP27NmDtrY2zJs3j3IQwT333IOrr766z/aiLMbXiRMn4O/vj2nTpmHFihWoqKgAMLFyoAJOJPX19TCZTFY7AAD4+PigurpapFZd2CzbdbBtXl1dDW9vb6vn5XI53N3drdbp7z16fwY5x2w24/7778fFF1+MmJgYAN3bSaFQQKvVWq17fhZDbeeB1mlpaUFHR4c9/pwJ6fjx43B2doZSqcSdd96JvXv3YsaMGZTDONuzZw8yMzOxefPmPs9RFuMnKSkJO3fuxP79+/HSSy+htLQUl156KVpbWydUDvIxeRdCCBnAPffcg5ycHPz0009iN2XSioiIQFZWFpqbm/Hhhx9i1apV+P7778Vu1qRSWVmJ++67DwcOHIBKpRK7OZPa4sW
Lhd9jY2ORlJSEqVOn4v3334ejo6OILRseOgMnEk9PT8hksj49W2pqauDr6ytSqy5slu062Db39fVFbW2t1fNGoxENDQ1W6/T3Hr0/g3Rbv349Pv/8c3z33XcICAgQlvv6+sJgMKCpqclq/fOzGGo7D7SORqOZUF/E9qZQKBAWFob4+Hhs3rwZs2bNwr///W/KYRxlZGSgtrYWcXFxkMvlkMvl+P7777Ft2zbI5XL4+PhQFiLRarUIDw9HcXHxhDomqIATiUKhQHx8PA4ePCgsM5vNOHjwIObNmydiyy5cISEh8PX1tdrmLS0tSElJEbb5vHnz0NTUhIyMDGGdb7/9FmazGUlJScI6P/zwA7q6uoR1Dhw4gIiICLi5uY3TX8M3xhjWr1+PvXv34ttvv0VISIjV8/Hx8XBwcLDKorCwEBUVFVZZHD9+3KqgPnDgADQaDWbMmCGs0/s9LOvQMTQ4s9mMzs5OymEcLVy4EMePH0dWVpbwk5CQgBUrVgi/Uxbi0Ol0KCkpgZ+f38Q6JsasOwQZtj179jClUsl27tzJ8vLy2Lp165hWq7Xq2UKGp7W1lR09epQdPXqUAWBbtmxhR48eZeXl5Yyx7mFEtFot+/TTT1l2dja7/vrr+x1GZM6cOSwlJYX99NNPbPr06VbDiDQ1NTEfHx+2cuVKlpOTw/bs2cOcnJxoGJFe7rrrLubq6soOHTpk1VW/vb1dWOfOO+9kQUFB7Ntvv2Xp6els3rx5bN68ecLzlq76V1xxBcvKymL79+9nXl5e/XbV/9Of/sTy8/PZ9u3baciE8zz88MPs+++/Z6WlpSw7O5s9/PDDTCKRsK+//poxRjmIqXcvVMYoi/HywAMPsEOHDrHS0lL2888/s+TkZObp6clqa2sZYxMnByrgRPbCCy+woKAgplAo2Ny5c9mRI0fEbtKE9t133zEAfX5WrVrFGOseSuSxxx5jPj4+TKlUsoULF7LCwkKr9zh79ixbvnw5c3Z2ZhqNhq1Zs4a1trZarXPs2DF2ySWXMKVSyaZMmcL+8Y9/jNefOCH0lwEA9sYbbwjrdHR0sLvvvpu5ubkxJycntnTpUnbmzBmr9ykrK2OLFy9mjo6OzNPTkz3wwAOsq6vLap3vvvuOzZ49mykUCjZt2jSrzyCMrV27lk2dOpUpFArm5eXFFi5cKBRvjFEOYjq/gKMsxseyZcuYn58fUygUbMqUKWzZsmWsuLhYeH6i5CBhjLGxO59HCCGEEELsje6BI4QQQgiZYKiAI4QQQgiZYKiAI4QQQgiZYKiAI4QQQgiZYKiAI4QQQgiZYKiAI4QQQgiZYKiAI4QQQgiZYKiAI4QQQgiZYKiAI4QQEUkkEnzyySdiN4MQMsFQAUcImbRWr14NiUTS5+fKK68Uu2mEEDIoudgNIIQQMV155ZV44403rJYplUqRWkMIIbahM3CEkElNqVTC19fX6sfNzQ1A9+XNl156CYsXL4ajoyOmTZuGDz/80Or1x48fx69//Ws4OjrCw8MD69atg06ns1pnx44diI6OhlKphJ+fH9avX2/1fH19PZYuXQonJydMnz4dn332mfBcY2MjVqxYAS8vLzg6OmL69Ol9Ck5CyORDBRwhhAzisccew4033ohjx45hxYoVuOWWW5Cfnw8AaGtrw6JFi+Dm5oa0tDR88MEH+Oabb6wKtJdeegn33HMP1q1bh+PHj+Ozzz5DWFiY1Wds2rQJN998M7Kzs3HVVVdhxYoVaGhoED4/Ly8P+/btQ35+Pl566SV4enqO3wYghPCJEULIJLVq1Somk8mYWq22+nn66acZY4wBYHfeeafVa5KSkthdd93FGGPsv//9L3Nzc2M6nU54/osvvmBSqZRVV1czxhjz9/dnjz766IBtAMD++te/Co91Oh0DwPbt28cYY+zaa69la9asGZs/mBBywaB74Aghk9rll1+Ol156yWqZu7u78Pu8efOsnps3bx6ysrIAAPn5+Zg1axbUarX
w/MUXXwyz2YzCwkJIJBJUVVVh4cKFg7YhNjZW+F2tVkOj0aC2thYAcNddd+HGG29EZmYmrrjiCixZsgTz588f0d9KCLlwUAFHCJnU1Gp1n0uaY8XR0dGm9RwcHKweSyQSmM1mAMDixYtRXl6OL7/8EgcOHMDChQtxzz334Pnnnx/z9hJCJg66B44QQgZx5MiRPo+joqIAAFFRUTh27Bja2tqE53/++WdIpVJERETAxcUFwcHBOHjw4Kja4OXlhVWrVuGdd97B1q1b8d///ndU70cImfjoDBwhZFLr7OxEdXW11TK5XC50FPjggw+QkJCASy65BO+++y5SU1Px+uuvAwBWrFiBJ554AqtWrcLGjRtRV1eHe++9FytXroSPjw8AYOPGjbjzzjvh7e2NxYsXo7W1FT///DPuvfdem9r3+OOPIz4+HtHR0ejs7MTnn38uFJCEkMmLCjhCyKS2f/9++Pn5WS2LiIhAQUEBgO4eonv27MHdd98NPz8/7N69GzNmzAAAODk54auvvsJ9992HxMREODk54cYbb8SWLVuE91q1ahX0ej3+9a9/4cEHH4Snpyduuukmm9unUCjwyCOPoKysDI6Ojrj00kuxZ8+eMfjLCSETmYQxxsRuBCGE8EgikWDv3r1YsmSJ2E0hhBArdA8cIYQQQsgEQwUcIYQQQsgEQ/fAEULIAOgOE0IIr+gMHCGEEELIBEMFHCGEEELIBEMFHCGEEELIBEMFHCGEEELIBEMFHCGEEELIBEMFHCGEEELIBEMFHCGEEELIBEMFHCGEEELIBPP/mgO0HCEveEgAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(7, 4))\n", "\n", "plt.plot(np.array(train_losses), label='Train Loss', marker='o', color='#25599c', markersize=1)\n", "\n", "plt.xlabel('Epochs')\n", "plt.ylabel('Loss')\n", "plt.title('Training Loss Over Epochs')\n", "plt.yscale('log')\n", "\n", "plt.legend()\n", "plt.grid(True, which='both', linestyle='--', linewidth=0.5) \n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "5ef7d214-6efa-4124-a587-b47e98cac7d6", "metadata": {}, "source": [ "The following plot shows the trained neural network on the entire domain, approximating the solution, $u$, of the equation." ] }, { "cell_type": "code", "execution_count": 8, "id": "7798eb68-d6b5-47ec-b845-cf3d60dccf23", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAApIAAAGGCAYAAADIJ2F2AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAondJREFUeJztvXl8VNX9//+amWwghoCQBDTKZkVUxIJEUOsCGsRS8UOrKJWlAm5ghVoFFQOi4EItilTqguJXKe7+VBSNKLVqKopSN8SCKJSSKFKILGaZub8/hgyZe9+Tec+55869d/J+9jGPmss573PmZjLzmvd2AoZhGBAEQRAEQRCEFAm6vQFBEARBEATBn4iQFARBEARBEJQQISkIgiAIgiAoIUJSEARBEARBUEKEpCAIgiAIgqCECElBEARBEARBCRGSgiAIgiAIghIiJAVBEARBEAQlREgKgiAIgiAISoiQFAQXOf3003H66adrtfnNN98gEAjg0Ucf1WpXNytWrECfPn2Ql5eHQCCAnTt3ur0lIY2sWrUKgUAAq1atcnsrgiDYQISkIKTAp59+il//+tc44ogjkJeXh0MPPRRnnXUWFixYkPa9LF26FPPnz0/7ujr44YcfcMEFF6BVq1ZYuHAh/t//+3846KCDyLGPPvooAoFA3KOwsBBnnHEGXn311TTv3Ds0CrFEj2XLlrm9RQDAX/7yF89/qREEQZ0stzcgCH7hvffewxlnnIHDDz8cEyZMQHFxMbZs2YJ//vOfuOeeezB58uS07mfp0qX47LPPcM0118RdP+KII7Bv3z5kZ2endT+p8MEHH+DHH3/E7NmzMXjwYNacW265BV27doVhGKiursajjz6KoUOH4qWXXsIvf/lLh3fsXa6++mqceOKJlusDBgxwYTdW/vKXv6BDhw4YO3Zs3PVf/OIX2LdvH3JyctzZmCAIWhAhKQhMbrvtNrRt2xYffPABCgoK4v7tu+++c2dTBIFAAHl5eW5vo1ka75f5PjbHOeecg379+sV+vvTSS1FUVIS//e1v2oTknj17EnpGdfPTTz8hJycHwaC9wNCpp56KX//615p2lT6CwaDnX6eCICRHQtuCwGTjxo045phjSPFTWFgY93NDQwNmz56N7t27Izc3F126dMENN9yA2traZtdoDON+8803
cdfN+WSnn346li9fjm+//TYWyuzSpQuAxDmSb775Jk499VQcdNBBKCgowHnnnYd169bFjZk5cyYCgQA2bNiAsWPHoqCgAG3btsW4ceOwd+/epPcIAJ5++mn07dsXrVq1QocOHfDb3/4WW7dujf376aefjjFjxgAATjzxRAQCAYu3ikNBQQFatWqFrKwD34cT5d1R92Ts2LFo06YNNm7ciKFDh+Lggw/GqFGjAAD79u3D1VdfjQ4dOuDggw/Gr371K2zduhWBQAAzZ86Ms71161b87ne/Q1FREXJzc3HMMcdg8eLFcWMa97Vs2TLcdNNNOPTQQ9G6dWvU1NSgvr4es2bNwpFHHom8vDwccsghOOWUU1BRUZHyPUlEbW0tpkyZgo4dO8aez3/+8x/L8xk7dmzsddSUxtdFUx555BGceeaZKCwsRG5uLnr16oX7778/bkyXLl3w+eef4+9//3vsddqYE5zod5Xs9dO4zzZt2mDr1q0YPnw42rRpg44dO+Laa69FOBxWvk+CIKSOeCQFgckRRxyByspKfPbZZzj22GObHTt+/HgsWbIEv/71r/GHP/wB77//PubOnYt169bh+eeft72XG2+8Ebt27cJ//vMf/PnPfwYAtGnTJuH4N954A+eccw66deuGmTNnYt++fViwYAFOPvlkfPTRRxbxcMEFF6Br166YO3cuPvroIzz00EMoLCzEHXfc0ey+Hn30UYwbNw4nnngi5s6di+rqatxzzz1499138fHHH6OgoAA33ngjjjrqKDzwwAOxcHX37t2TPuddu3Zh+/btMAwD3333HRYsWIDdu3fjt7/9bfIbloCGhgaUlZXhlFNOwbx589C6dWsAUaHy1FNP4ZJLLsFJJ52Ev//97zj33HMt86urq3HSSSchEAhg0qRJ6NixI1599VVceumlqKmpsaQdzJ49Gzk5Obj22mtRW1uLnJwczJw5E3PnzsX48ePRv39/1NTU4MMPP8RHH32Es846K+lz+PHHH7F9+3bL9UMOOSQm/saPH4/HH38cF198MQYOHIg333yTfD6pcP/99+OYY47Br371K2RlZeGll17ClVdeiUgkgquuugoAMH/+fEyePBlt2rTBjTfeCAAoKipKaJPz+mkkHA6jrKwMpaWlmDdvHt544w386U9/Qvfu3XHFFVfYem6CIKSAIQgCi9dff90IhUJGKBQyBgwYYFx33XXGa6+9ZtTV1cWNW7t2rQHAGD9+fNz1a6+91gBgvPnmm7Frp512mnHaaafFfn7kkUcMAMamTZvi5r711lsGAOOtt96KXTv33HONI444wrLPTZs2GQCMRx55JHatT58+RmFhofHDDz/Erv3rX/8ygsGgMXr06Ni18vJyA4Dxu9/9Ls7m+eefbxxyyCGJbo1hGIZRV1dnFBYWGscee6yxb9++2PWXX37ZAGDcfPPNluf5wQcfNGuz6VjzIzc313j00UfjxlL3KdE9GTNmjAHAmDZtWtzYNWvWGACMa665Ju762LFjDQBGeXl57Nqll15qdOrUydi+fXvc2JEjRxpt27Y19u7dG7evbt26xa41cvzxxxvnnntu0vtgptFmose2bdsMwzjwerzyyivj5l988cWW5zNmzBjyNdX4umiK+XkYhmGUlZUZ3bp1i7t2zDHHxL3Gzftv/F2l8vpp/N3dcsstcTZPOOEEo2/fvpa1BEFwDgltCwKTs846C5WVlfjVr36Ff/3rX7jzzjtRVlaGQw89FC+++GJs3CuvvAIAmDp1atz8P/zhDwCA5cuXp2/TALZt24a1a9di7NixaN++fex67969cdZZZ8X225TLL7887udTTz0VP/zwA2pqahKu8+GHH+K7777DlVdeGZf7du6556Jnz562n/fChQtRUVGBiooKPP744zjjjDMwfvx4PPfcc7bsmr1XK1asAABceeWVcdfNxVSGYeDZZ5/FsGHDYBgGtm/fHnuUlZVh165d+Oijj+LmjBkzBq1atYq7
VlBQgM8//xz//ve/lfZ/8803x+5L00fj77rx93v11VfHzTN7S1Ol6fNo9Bafdtpp+Prrr7Fr166U7am8fqjX6ddff53y2oIgqCOhbUFIgRNPPBHPPfcc6urq8K9//QvPP/88/vznP+PXv/411q5di169euHbb79FMBhEjx494uYWFxejoKAA3377bVr33LjeUUcdZfm3o48+Gq+99pqlyOTwww+PG9euXTsAwP/+9z/k5+envE7Pnj3xzjvvqD2B/fTv3z+u2Oaiiy7CCSecgEmTJuGXv/ylUvVvVlYWDjvssLhrjb+/rl27xl03/z6///577Ny5Ew888AAeeOAB0r65CMtsE4hWo5933nn42c9+hmOPPRZDhgzBJZdcgt69e7Oew3HHHdds5Xvj8zGnD1C/p1R49913UV5ejsrKSkv+7K5du9C2bduU7KX6+snLy0PHjh3jrrVr1w7/+9//UlpXEAR7iEdSEBTIycnBiSeeiDlz5uD+++9HfX09nn766bgx5uIEDonmpLuAIBQKkdcNw0jrPpojGAzijDPOwLZt22LevFTvX25urnLVdCQSAQD89re/JT2CFRUVOPnkk+PmmL2RQLQNzsaNG7F48WIce+yxeOihh/Dzn/8cDz30kNK+7MC9fxs3bsSgQYOwfft23H333Vi+fDkqKiowZcoUAAfujZMkeo0KgpBexCMpCDZp9JJt27YNQLQoJxKJ4N///jeOPvro2Ljq6mrs3LkTRxxxREJbjZ4/8ykvlBeTK1Qb11u/fr3l37788kt06NBBS8ubpuuceeaZcf+2fv36Zp+3Kg0NDQCA3bt3A0jt/iWi8fe3adMmHHnkkbHrGzZsiBvXWAEdDofZvTAT0b59e4wbNw7jxo3D7t278Ytf/AIzZ87E+PHjbdkFDjyfjRs3xnn7qNdDu3btyBOGzPfvpZdeQm1tLV588cU47/Vbb71lmavyOk3X60cQBPuIR1IQmLz11lukR64xB63xQ3ro0KEAYDl15u677waAZqtlG8OPb7/9duxaOBwmQ6cHHXQQKxetU6dO6NOnD5YsWRInEj777DO8/vrrsf3apV+/figsLMSiRYvi2hy9+uqrWLdune0qYTP19fV4/fXXkZOTExPsRxxxBEKhUNz9A6JNsbmUlZWRc8ynF4VCIYwYMQLPPvssPvvsM4ud77//nrXeDz/8EPdzmzZt0KNHj6Storicc845AIB777037jp1KlL37t2xa9cufPLJJ7Fr27Zts3QaaPQGNv172LVrFx555BGLzYMOOoh1/GW6Xz+CIOhBPJKCwGTy5MnYu3cvzj//fPTs2RN1dXV477338OSTT6JLly4YN24cAOD444/HmDFj8MADD2Dnzp047bTTsHr1aixZsgTDhw/HGWeckXCNY445BieddBKmT5+OHTt2oH379li2bFnM89aUvn374sknn8TUqVNx4oknok2bNhg2bBhp96677sI555yDAQMG4NJLL421/2nbtq2lL6Iq2dnZuOOOOzBu3DicdtppuOiii2LtW7p06RILe6ry6quv4ssvvwQQzT1cunQp/v3vf2PatGmxvM22bdviN7/5DRYsWIBAIIDu3bvj5ZdfTqlhfN++fTFixAjMnz8fP/zwQ6z9z1dffQUg3sN2++2346233kJpaSkmTJiAXr16YceOHfjoo4/wxhtvYMeOHUnX69WrF04//XT07dsX7du3x4cffohnnnkGkyZNYu33H//4B3766SfL9d69e6N3797o06cPLrroIvzlL3/Brl27MHDgQKxcudLiYQWAkSNH4vrrr8f555+Pq6++Gnv37sX999+Pn/3sZ3GFQ2effTZycnIwbNgwXHbZZdi9ezcefPBBFBYWxjzzTe/n/fffj1tvvRU9evRAYWGhxeMIOP/6EQTBIVytGRcEH/Hqq68av/vd74yePXsabdq0MXJycowePXoYkydPNqqrq+PG1tfXG7NmzTK6du1qZGdnGyUl
Jcb06dONn376KW6cuf2PYRjGxo0bjcGDBxu5ublGUVGRccMNNxgVFRWWtja7d+82Lr74YqOgoMAAEGvbQrW6MQzDeOONN4yTTz7ZaNWqlZGfn28MGzbM+OKLL+LGNLZ5+f777+OuJ2pLRPHkk08aJ5xwgpGbm2u0b9/eGDVqlPGf//yHtKfa/icvL8/o06ePcf/99xuRSCRu/Pfff2+MGDHCaN26tdGuXTvjsssuMz777DOy/c9BBx1Errlnzx7jqquuMtq3b2+0adPGGD58uLF+/XoDgHH77bfHja2urjauuuoqo6SkxMjOzjaKi4uNQYMGGQ888EBsTGOrm6efftqy1q233mr079/fKCgoMFq1amX07NnTuO222yxtpcwka//TtK3Pvn37jKuvvto45JBDjIMOOsgYNmyYsWXLFss4w4i2uTr22GONnJwc46ijjjIef/xxsv3Piy++aPTu3dvIy8szunTpYtxxxx3G4sWLLa+Tqqoq49xzzzUOPvhgA0Ds9Z6oVRPn9ZPod0ftUxAEZwkYhoey5wVBEDzK2rVrccIJJ+Dxxx+PnYDjdwKBAMrLy7V5pQVBaHlIjqQgCIKJffv2Wa7Nnz8fwWAQv/jFL1zYkSAIgjeRHElBEAQTd955J9asWYMzzjgDWVlZePXVV/Hqq69i4sSJKCkpcXt7giAInkGEpCAIgomBAweioqICs2fPxu7du3H44Ydj5syZsfOiBUEQhCi+Cm2//fbbGDZsGDp37oxAIIAXXngh6ZxVq1bh5z//OXJzc9GjRw88+uijljELFy5Ely5dkJeXh9LSUqxevVr/5gVB8A1nnXUW3nnnHezYsQN1dXXYsGEDysvLkZWVWd+9DcOQ/EhBEGzhKyG5Z88eHH/88Vi4cCFr/KZNm3DuuefijDPOwNq1a3HNNddg/PjxeO2112JjGtunlJeX46OPPsLxxx+PsrKylNqFCIIgCIIgOEGqzq758+fjqKOOQqtWrVBSUoIpU6aQLcJ04duq7UAggOeffx7Dhw9POOb666/H8uXL45oFjxw5Ejt37sSKFSsAAKWlpTjxxBNx3333AYge7VVSUoLJkydj2rRpjj4HQRAEQRCERDz55JMYPXo0Fi1ahNLSUsyfPx9PP/001q9fj8LCQsv4pUuX4ne/+x0WL16MgQMH4quvvsLYsWMxcuTI2KEYusmsOI2JyspKy9FlZWVluOaaawAAdXV1WLNmDaZPnx7792AwiMGDB6OyspK9TiQSwX//+18cfPDBSucrC4IgCILAxzAM/Pjjj+jcuTOCwfQGV3/66SfU1dUpzc3JyUFeXh57/N13340JEybEDrxYtGgRli9fjsWLF5POrvfeew8nn3wyLr74YgBAly5dcNFFF+H9999X2i+HjBaSVVVVKCoqirtWVFSEmpoa7Nu3D//73/8QDofJMY0naFDU1tbGHeG1detW9OrVS+/mBUEQBEFoli1btuCwww5L23o//fQTilu1xS6oCcni4mJs2rSJJSZVnF0DBw7E448/jtWrV6N///74+uuv8corr+CSSy5R2i+HjBaSTjF37lzMmjXLcn3L5ieRn9868UQjom8TOm1RBHyVPsvD6XsmNI8XXlNe2AMASOBCcANfJrLR1NTsRcnhF+Lggw9O67p1dXXYhTrMw0C0SlFC7UMDrq16D9u3b48d6woAubm5yM3NtYzfvn17ys6uiy++GNu3b8cpp5wCwzDQ0NCAyy+/HDfccENKe02FjBaSxcXFqK6ujrtWXV2N/Px8tGrVCqFQCKFQiBxTXFyc0O706dMxderU2M81NTUoKSnBQfk5OCjf+mJoJMD89Aj4qwZKG55JC/Dqm60bQtgrwksFj7ycOPg0VV2wgWfe73TiwsvYrft4UDAbrQKpSaigEQAisPSi1Xm61KpVqzBnzhz85S9/QWlpKTZs2IDf//73mD17NmbMmKFlDTMZLSQHDBiAV155Je5aRUUFBgwYACCaq9C3
b1+sXLkyVrQTiUSwcuVKTJo0KaHdRN8eGiL1aIgkdndTQjKg+EHNFaVegP0cNb4JOX1/nBb75Jujn0WdZljCy0fazIB4y93EjS/vbnx5cFx06TTv8b/fYAgIpvh8gwaASDQcb/ZIUnTo0CFlZ9eMGTNwySWXYPz48QCA4447Dnv27MHEiRNx4403OpJP6ishuXv3bmzYsCH286ZNm7B27Vq0b98ehx9+OKZPn46tW7fiscceAwBcfvnluO+++3Ddddfhd7/7Hd5880089dRTWL58eczG1KlTMWbMGPTr1w/9+/fH/PnzsWfPnlhiayo0RGrREAkBoMUTKW6IPxZ1cekNoWF5nkZYr32NgkpVcAYCzr7LBQz/fFEQ7GN4/VPTg+j8sugVIe/0e7jT4lWrUE1myuW3yEAwgGCKz7fxfT0/Pz9OSCZCxdm1d+9ei1gMhaK6xKnfv6+E5Icffogzzjgj9nNjeHnMmDF49NFHsW3bNmzevDn27127dsXy5csxZcoU3HPPPTjssMPw0EMPoaysLDbmwgsvxPfff4+bb74ZVVVV6NOnD1asWGHJSeDQYNShwYje0oBBCEniRUe+cTB+1+wwucNeLMc9f9T+GcLUzhsy6zkRe/CCwLW1png8bWNIHm6z6HyN+Ul8c/+e0y1odQtXnULF66H/UAgIpbjFkMLtSebsGj16NA499FDMnTsXADBs2DDcfffdOOGEE2Kh7RkzZmDYsGExQakbXwnJ008/vdkXKnVqzemnn46PP/64WbuTJk1qNpTNJWyEETYaEv475WViey4tY9TDxap/oDrfdNhCiSMayQ8nnheUda/ZoXlCXCreMzcEuqHZcyy0bKjXcCYKbc77gxuil/Me4rRwtfOZkUyUup1XHFTwSAYVIk3JnF2bN2+O80DedNNNCAQCuOmmm7B161Z07NgRw4YNw2233Zby2lx825DcS9TU1KBt27bYUHUfDs5vBYDvfdTpWXTaS6kcBrbjHXRB9GoVlwq2da9J7yP93kc/5fUKzuB3r7cXXsNeuYdu34uamj1oVzAcu3btYoWJ9a0b/bxfkn8GWqdYbLPXaMCYmrfSvmen8cYrUhAEQRAEQfAdvgpte536SD3qI9FbSlVzkR5JjZ5LrpeJU8hh51uvda/MMDOZD6myHr2mcpiZ8oq6EMYOaAyPBZwOaVGhc0dX9DdeKZRzGm76hNveroR4wRvoQnoA/ffszXB9ugiG0hPa9gMiJDVSGw4iJxz9gwsSVb3UtVDA+qYQDBBvtqapXFFKwfnQcrrnJV3Brq+AhS0uOSkD7D9+NXHJDt/7KQfTo7LRSx9ETXGjatgrIVIKzqvHnZY91N9gml9TbvzePChe3S62CgaAVDvpBDMvRRiACEmt7GsIItQQfWWFSCFpnUOLy+TjuKKUghKqljdldpGONz1/dop59OajqvU7dMVLqbFKMuBGjzxWcYF/cFro6SyuckOgc8W34/fRvJ7jLXzSL2apvxvHv4gkEa9uF24FQwEEU2wkmaoH0y+IkNTIvkgQof0eyewgzyMZJN4AaBEaMP1MCUnei5TrLbXaVxeqZrhtjzjilR3S53psGSJIVahy33z5Hwxq9sk1NYZdnPZIUr9LL4hEnSLC6Sp6vX0YbezDA4LZ6Z6U6RazQBo8tg4LuWTi1W2PZCgYfaQ0x5mtuI4ISY3UhQOoDUffkBoi1jcmSqzRgjP5XK4ApeDYp7yiDcSaHAFKf2mjhDAltDkwvZYeEKpcT516hbz6m6vOD1OdDdvpNjLazKewD06bF4fFn0aPhs5baK/Ni9o9c0MIqwpCN3JDOR5bOwLXjfSDpvcx4rZHMqjgkfRoao1dvJskIwiCIAiCIHga8UhqZF9DEMGGxMU22cS3l/qIWo6kqicTAPlV0vxNiZvjCUUvJfUcCScuCedLIDsMzxjD9fLxKuuZtlT7Z9pw1akfF6mvuTnpfVSylMC+LQ+M+zmFqr9ex/P22J0ZXPAi
OpxukG7PqNbTgRx/Ps7lsbp9rGUglHqxTYamSIqQ1MlPTYRkFiH0GgjRmEW8EGkhGX+REqBccUlhFo5cUUr/HSUXpbRo5K3JOZaqJYhSW/Y15hiqilc3RKOdvENfCwaNJz3Z2gdZWJZekav76L1MK7ZxOsxPrqkQ+ne92CYIhdB2ZiJCUiO1YSC0P0eynlAylNDjC8LmfwbovExK0FJkM17hlLjkeC7JP3fSK2qFztVsfr0omS9KKZwXqlRLDjVvlK0CDU6+oua8NGWh7au8Oh5Oi1wnBa3TYpZ8Pg67otwotnG7ejzi8pGuUSGZ4hxntuI6IiQ1sjcMNB61TQkzWlzywt1mcckVdVxBGzFV7HJFF6cAR1WAAkADUUlssc/89FMNp3O/dKZblEbXND95qi+pdR7XiWj+MLITwuSsyf3w4+yD++HtdIjUDQ+S261ZbOPg/mlPuMPPx3Ch56VHBK1lD5rklNtV28FgAEHOm37TORlabCNCUiN7GoBIM0Iyh7gWJv4WqDY+PKHH88IROpXl8aRD7smv2akA5+yDK7qojwqOCOW+V6RblOpfk7oZyT9guUJVZ9Uz5wNRt+dJZ4jU8XCoj4UeAGf370ojdhcaelPRAzfaHZhhRmWS4XaOpFL7Hw/cfifIVE+rIAiCIAiC4DDikdRIbRho7MVNeRrriS9QXM+leS4VEucU6QBAhPhWqjMczfFIcuYlmmv2Zqp6MhPByvGk5pHe5fiLZs8ytV50HLUv68WwyR4ZvmfaovZmnWe9RnvV9Xk3KXR6PFU9i1o9npq8NAfwj8fQ6XOczaFsp4s03DiX2pVwPQMnvaJue1eDIQltNyJCUiM/hQPA/mKbOuJvOIdQGtSHMCf0TH94J58XtU+JGX2hc7M4psKvdIsgrjhu/udE81TzE1UFKMAToXZyPC3rac/BNNn3jCiliL/bTotSEkVBqL0ARKMwpQusNH6I+0iUkmu2AKHKxXIvHBSzbgtlpWIbb/yatCNCUiN764FwffS/s4mzkKgP1zDxCcLxXHK9m3wPlXkeVaSTfF7UPsdjyPN4ckQomVOqsZLbjjjjiD9bOZ6mn7lCTNUzyvGKJkKnKOXcM52ilLumF0QptQ8RpXpoCUKV3AMzrzSdgtZt8ax0so3GY2i9hAhJjfwUOXC0MulpJP4W64gXIl2UYxZP1BjrNdVxdgqDrMUwXAHHLcAxh3O5IXfemlbhqCZAaVvUvqzX7Ig/y7zkQ6K2NBYe6RSl3C9D1j3ou4eA9fXfUkWpW/vg7MHJkDigV8R5QZyxW1+53LORwu09BUJAMMXDszNTRoqQ1MreBqChufY/xDVqHCU4wyGzeCLGEO/4qkKSK0qpa+bnRM/jeRE5Xj7VeYnnJn/jtrN/1X1RqNrnhuHN94IUYsQ8jjfQTr4ohXmvdkQp5zXQYkUp4IkQvvPtbfyTZ6qMKxXs6jjdpzIVxCN5ABGSGqmrDQFZ0a8okWzrm0SEqP3n5khaP3iIMcRFruA0eyC5QpISwpy9csVNDnHRLLL050gm93hyBSglOHm2kk7bbz/5HuwIHidD81wBSsHx2OosYqL3oG7fOk+fKAW4wlRNlALeDZ1r7deovQDKTPpFnBfC5FySpRG4nW4YVPBI+ku288nU5yUIgiAIgiA4jHgkNVJf18QjSbiZIlnWb39hwnNJeRhYHkliHlX0Q3lbOKE2ft6kaQ/MIiCdhUHcVkiqxTb8VkXJ56l6Mqm9Od3w3E6lNcdjSMH2smraF3dvdrybrAp8Re9mdK7avfaCdxPwRri7JRyl6PQxirZI5i112ZsaCgQQSjG0HeKGxXyGCEmN1NUFYWRF/zBJIcm9xgiLcz/Q2QU4IfOY5GIzOk5fmJwbOucIvTDTPifXUbW/ZXRu8t+bvTB88gp2Cm5ontOnkm/fbNs6xlaFuWWMmgBNtDczOqvhKbwqSqNzk89zo5UThc4wvOPiL91HKToevrdD8/fC/apthfY/Htbt
dhAhqZGG+iACsYoaQgxGrO7BCKF4ON7MLEpsMoUkx7PI9T5SOYzK3k22tyv+Zzq30npNXXByPZ7Jr/ELfqzXKHjeU15urno+p0bBwxbQycfYqTCnxzlXeEShu3m9dV9qvyNqTW+LUssulPaQCF8JVfN6DntdbZFE5LreR1KlIbkU2wjJqK0LIRJKHNrOIkLb1EcPx3NJ2ifD5DwRYRZe1IdTNlUYRPWzNOnliEbvJmAVf6oh8URzzddID66ix9NOmJxC1XvKsRW1Z66EpuyrCR7d7X9YQs/h0Dyn8p20RcyjcDqNwAuilLumnQIlMy1FqLLWcyHkTu4jich1/WQb8UjGECGpkfraEIxgYiEZJr2P1jcdShCaP2q4nkxKvFLV4+ZrnGps7jgyZE01bGcKTvMJQXZC25y57HxOZVvJw+uJUPV4cmwlssexr9pCycvHRToZmucKUAqd+aJs+w6H05VTKmysaUWneNIpevQKVRbpDrknoomX0m2PZCBoIJDiUTWpjvcLHnl1CIIgCIIgCH7Ddx7JhQsX4q677kJVVRWOP/54LFiwAP379yfHnn766fj73/9uuT506FAsX74cADB27FgsWbIk7t/LysqwYsWKlPfW0BAAGlLT5uwCnLA5R5LpTWCH2OOvUSFxrpfSUmzDDCdSXkraq5Q8bEeda67qubRXzKPPFoXO0DnlzVGv2k6+D673iPMasGNfNZxup2cnJ4/YTuhWZ+ERhXpDdVXXmR1PjtqaOr2bOsPwNM56uoJeKcoxvBPaDgRT7+fus/7vbHwlJJ988klMnToVixYtQmlpKebPn4+ysjKsX78ehYWFlvHPPfcc6urqYj//8MMPOP744/Gb3/wmbtyQIUPwyCOPxH7Ozc1V2l9dXQiRZkLb3LxGDhHiXY7KV8zKIsYxwuLckHg2o8k6u/CFWxjEODmHeuPOJsQlJxTslTA5T5BQ9p0NnasKKq54Uq9gt47RWcDi5XxOK7ybqBxiZxYROF14RM81z1MXIHZyQa3oE0JOC1W9OaTqBL0U2g4YCKR4D1Id30gqDjQA2LlzJ2688UY899xz2LFjB4444gjMnz8fQ4cOVVo/Gb4SknfffTcmTJiAcePGAQAWLVqE5cuXY/HixZg2bZplfPv27eN+XrZsGVq3bm0Rkrm5uSguLra9v3B9EAg585WDI/TIecQ7AMebyc63JO3HjyMLd5i9JTlFP6SAI72nlLik9sEQJB7wbpL2iftlx75ZgNgpYjLPVRXLUfvUa5izV24FO2GeJcZ0eub05TDaEwLJ92onx5MrQomJivN0JxR6ZR9m9Ak9+u/S2f3TXzTp/3aDdHkkU3Wg1dXV4ayzzkJhYSGeeeYZHHroofj2229RUFCQ+uJMfCMk6+rqsGbNGkyfPj12LRgMYvDgwaisrGTZePjhhzFy5EgcdNBBcddXrVqFwsJCtGvXDmeeeSZuvfVWHHLIISnvMdIABPZ3oG5gBonsFM1w5mWTHk+qNVH8filPJkWQkTxMHw3JC51zGq+T8xhhWsoWYG1pRAsx972bieY6aZ/r5eNVgPO8O057PFU9kjo9njpbKEXR56X0uxfUjH7vmhvh+ni4r0V1nA6dW68lE6pOC9lkBIIG6/PPPCdVUnWgLV68GDt27MB7772H7OxsAECXLl1SXjcVfCMkt2/fjnA4jKKiorjrRUVF+PLLL5POX716NT777DM8/PDDcdeHDBmC//u//0PXrl2xceNG3HDDDTjnnHNQWVmJUIg+SLO2tha1tbWxn2tqagAAofoIQqHo22KYEJKUuAxSikcRO7aClk9OXh9M0kvJ8J6Sf4BUb0ym4FGdx7lltBCzXuQITt2iUfVkG64XTlUIc9oE0Xu1I6icDhc75/HU3TNSZwU7jU7vrJp9W15QM7ZEiRc8kk675/Ttlf5imLod1z2SAQWP5P7b2KgZGsnNzSVT6lQcaC+++CIGDBiAq666Cv/f//f/oWPHjrj44otx
/fXXJ9Q0dvGNkLTLww8/jOOOO86SVzBy5MjYfx933HHo3bs3unfvjlWrVmHQoEGkrblz52LWrFmO7lcQBEEQhMyjpKQk7ufy8nLMnDnTMk7Fgfb111/jzTffxKhRo/DKK69gw4YNuPLKK1FfX4/y8nJtz6EpvhGSHTp0QCgUQnV1ddz16urqpPmNe/bswbJly3DLLbckXadbt27o0KEDNmzYkFBITp8+HVOnTo39XFNTg5KSEoTCEWQ1JA5BU15Kc0iZi2r4285c2o1PeFkZXxUpW9xiJE44wU5RC2cMlYPJ8VzqDJNT9nSfnMMp5lG1r7O5OTXOnheO8DizKpW5nhud7pR0525y7fnIy2cntJ32HE8KZz2GnCND+ejZq97QferY6SO5ZcsW5Ofnx66rFvhSRCIRFBYW4oEHHkAoFELfvn2xdetW3HXXXSIkc3Jy0LdvX6xcuRLDhw8HEL1hK1euxKRJk5qd+/TTT6O2tha//e1vk67zn//8Bz/88AM6deqUcEwiN3So3kAomFolmUF82nHyKzliLeFcTl6jjXzLoMV7TohNYg/c0Lx5LruaXDEsaysHk1Nhrhgmp+zpbs7OK9pQs2+nlZB6YRBP/PHmJhebgI1TcmyFns3oFbjeCJ1TOHt/yLCs4oqqAjRZEUoq0K87f4hS10PbNopt8vPz44RkIlQcaJ06dUJ2dnZcGPvoo49GVVUV6urqkJOTk9qmGfhGSALA1KlTMWbMGPTr1w/9+/fH/PnzsWfPnlgS6ujRo3HooYdi7ty5cfMefvhhDB8+3FJAs3v3bsyaNQsjRoxAcXExNm7ciOuuuw49evRAWVlZyvvLaoggK5Ta2wol2CjM4jLVJN9U5zqdbxkk3sG44ti8/yzCfj1jHpDAI2l+cyB+pao5mDq9m9SaOr2b1FxusQdHEPIFoj776qLROlfnKTl2PIa8VkXWMXaKhay44fG0Y191PY1eMFfyOS3GNNqiermq79VOm6Z0kI4jElUcaCeffDKWLl2KSCSC4P4Fv/rqK3Tq1MkREQn4TEheeOGF+P7773HzzTejqqoKffr0wYoVK2L5A5s3b47duEbWr1+Pd955B6+//rrFXigUwieffIIlS5Zg586d6Ny5M84++2zMnj1bydUciBgI7H9Xby7EHQfl7qpnVGhnEcU89bxXKUdIUuKMW4nOEqqkR5IcqWaf+HQlhSQpmE0fwkzBo1ysQm1BOSda3Uums9hGtbeePY+kPlucSnR7ldbpLQyicSO0qs/jace+FWcFrlbPnw2BZT06U13oue0ZdT20naY+kqk60K644grcd999+P3vf4/Jkyfj3//+N+bMmYOrr7465bW5+EpIAsCkSZMSKvFVq1ZZrh111FEJO+C3atUKr732mra9ZdVHkLU/tN1ACETHxSXx10iJPyrvkCNCVT2GFCFCPdHilbKf/A2E8qhS++J4M7meTI7Iohqsk4Stl1RD5/yeiGqfAqq9K3XmblL2dXo3KfjeO7WqajteMp0n52Smx1PNvvP752BDQOn08rkcmvdzaDsVUnWglZSU4LXXXsOUKVPQu3dvHHroofj973+P66+/PvXFmfhOSAqCIAiCILQUUnWgDRgwAP/85z8d3tUBREhqJGgYMa9dFuFB1OqlpDyUhP1IA89LycmJVC2GoaCOeCRtsULUvMpxbl6peS43JE4V+JipI35ttnIYPZqDyUG1lyVgpyqcd41jjztPNUdSNUxOozt0683QuXqBj50wKeP9TusRhnbuq9qiXigyMuOFhuSqVduZhghJjQQbIgeqtokcRp3iMsD9FGOKS/MbDBkSZ+ZNmlsJUWFzO6LUXD3eQFTWUPmWWcSrnbM31ZA4NVf3EYmqOZiqglNnhbn6GcW0SHS6ObvOvEwzOsPk1Jp6C4P0tkLSGTqncVr0Ol3go2u99K+pVYCCchm4R7pC235AhKRGgmEDwdg7HvHnolFc0oUvvD/RAPHJYO5xye1vqdoz
kitK3ci3ND8nMqe0Qc0LSnkyyefoxhsOeVSIqjG1vD2nPZJO99TU6d3U2WeT2/aIRs2j5/T53vzCHfMerGPseQzTLQhbhuhNJkLtCFIdRKu2U7svqVZt+wURkhppWrVNyjyqUpkYxxGXlNeSLS4JQWsOlYcJMUt57ziFO3aEGGeu3jA5kJVtWo/yWjL7YHJaFVHUE2/cqg3JvRI6t17jFfyohsDteCTTHU7nNFi3t6bVmOpxi5R9Nzye3ih84fUE1bsP/3gknRS9Ea1HTKaOeCQPIEJSI8FIBMFIVChEqJAv0/vIEZfckLiyuCT2yhWXqu1/uNXkFnFGzKtn7otqsm4OlVNhcsqTyRGJqp5MgOfN9FPonJunyQ3xWtZz2CPptHeTQlVw6q6q9oLHU9WW7nPNVXFHCPvHI5lsTdertqHQ/sfxM9HdQYSkRoIRo4kosIoKneLSTr4lR1xyWuwA9Mk85h6Xqp5MgPbyqeZgUsdAUoLTHCpX9ZRSc1U9mQDTm+nlZG7z9omXId/7mDxXkxtGddoj6bR3k8Iq9PTat+Ivj6cVfSF37wpQIJM8ku4X24hHspEMfVqCIAiCIAiC04hHUiPxxTYU+ryUqrmVAM9LSfm+uN7NoGk29xQe0ovIKMrRX8zDCFGz8yZN+7LxF6fq8eQW+KQ7B5MKiZMeMebXXWvoObnXMrqmmufSjTC5ag6mHe8jZ64boXNueySOLf0nC5nh3ex0ezO968lMvmaCc0bShrT/OYAISY0EjOaLbWh44tIsGLgFtjqrwlXFpVlYAvQpPNSbDin+NFZVqxbzcPMtLbaJa9xzxzlhcSp8zybNb3LsE3G44WLF+AolIlQFp1dC26o5mHxBGP+zG6FzmvTmW3LteaWCXRUv5nO6fkSiwlnbmRraFiGpkaY5kvY6qRCzzU0Kqf6QzDVVq8IjxCeiqiglszSJ/pZUGyKz+ON6B3UW81Co5ltyKse5cL2uqgU+Oo+LtNNHkiJsOlbSToU5TfxAO+LASe9mdB/65ukVhGr2dXo8deZuUvb4HsT0FxCp2LaHcx7JZoN/aUA8kgcQIamR+GIbK1rFJdX9WqO4VK0mp+ZyRSl3/5z2QhROF/OohskpuE3Wzd5MVU9monGqBT4cz5zOanKA+cFJPB3qGXIEJ/cMc1XBace7qdrz0umqcKdFqTrp93iqFhDpDEe7U7jDxfseSSm2iSJCUiPBsJFyg1J1ccnwWgLK4tLpVkXkcyTWJJunm+zZ8UhSqOZgkrYUw+TcJuvWN3O9b67m/dsRqk6HzjlePq545fyR2DsRxznvZuI19cwDeKFtO/CaiFuvqYpqNzyeNDrD2E5XWrtbwe56+5+AQvsfjfmuXiJD9bEgCIIgCILgNOKRdIjmq7ebR6eX0qBcMIyZdgp3ON5NstiGsEXCOIWHOvXATqW4xZZiY3GuV5RGtU8lYUmxwIcbEtd5XKQtz6IipIeKU8Ci2BvTjTC57gIcM6oeQ8q+d0PiXLwbOle173wlt7dzJIMKxTZyRKKQnAhiVdsGM9bDOat6v+n4edw9NYSJi4SySLJeojVVWxWRELZY4pKYx2mUnghLaJtZ7c2y7UK+JQ1l36OfzMyQuPnXZCcHU1WkmAt+EtpivBT54im54PRKmNxfhTvWaxI6T9WWc/bdz5GUYptGREg6RIDyyBDjIorvomz7xDsfd67FFnFNa44kBZU3ab5nhECk96rehsiM6vnbdtrzUF5Qc84lN9+Sn+toHqPmyYzaYnhPNR4XaQfOB7+tSmtGhTnXe6F61rbOvEztFfgO28880u/xdNsj6XqOpBTbxBAhqZG49j8aBRwFLYDoPRGziStqXkTVwh1SQDM9hqxdsMUlYc3UhojTgoiLnfY89sLiuvC3J1P3+duO2mdWmHMEp07vZnSus+F0Djp7XrYc9Hkkue2RVO173SOJYCD1F3SGvvhESDoEJQTY4k8x+UNneyGyKbriPmz1z9QoLilvGqfH
JacFUdS+s+KPsyY3TE5Vj3MaqtM4Ky6pVlEc8e10NblOocRuzs6sMLf8DWr1bgJeqDrXiWrbo8yE+TlF9t7UZ9/7HskA2VUk2ZxMJEMdrYIgCIIgCILTiEfSZahwN6dQR7VIB1CvAFf1Umo/5UfRS8k94tHiAWP24uRUipMeQ4+EsamTecxePa/kW+okQrj+ODmYdrxmOnMAWR49e3+EilBPyn3vrx04nkuv7NV5GJ9Tyl7L5PZdD22HgtFHqnMyEBGSaYQb7ubkUlI5hi1VXJLhAqbYZO1DsVF61L7am50bxTzpR29I3CwI2a2KVIWqYjU5oDcH040Kc1UhzD9Jxd+Ck4Of9qqKkyfnuN3+B8FA6n/IGRraFiHpU6g8StUKcMAqXu14DJ0s3GFDeBpVxaWdCnNzGyJuCyJVdJ+1zam0Vj2tR9WTGR3Hy5t0ElVPJhenczDZn2mc7yaaX9a8YwHdv69C89hpq5TM4+h6jmSI6CDCmJOJiJB0GVUvpfY+lSZ79irMnSvcSW0fJghxySnAsbMHqyjltSDSWSmembhf4GNLuJrm6hY3Xgidq1aY83G6TY03kKrzA3jquYtHMoYIyQyC66XkiEu9Feb6QuKUNTufTZy8SUpUqPbGZLcgIq6pVm1zzwX3P/HPnRaD6c/LZIlSG9XkXgids+d5pKWRFX+LzZaI2x5JhBTa/2So61uEpE/hFumohsBJTykxTr0wyCv5lsnt6TwukrNeonncAh8zZLjbhSbiTobJqX3QYtB9T6Yt+8Q1L4TOtaPz2yILvQ3bBedxO0cyEFBo/xPwyh+YXjKzhEgQBEEQBEFwHN8JyYULF6JLly7Iy8tDaWkpVq9enXDso48+Gv3W0OSRl5cXN8YwDNx8883o1KkTWrVqhcGDB+Pf//6300+jWRpPyGn60Go/bFgfjPWofQWIh9k211YwErE8OPbZ+6IejHvBtZVVH7E8QqYH/byJNcPWB/m6MD0ikYD1ESYe1DjOXO48h22Fw9aH0h7CAUQiIB7WuQ0NwfhHvfVh616b17NhP2wg7kG8VMiHeV6iR10k/kGNsWPfYov8HVkfyvbZ9ydAPNTXVdmHl+07f/+bPLR+KirQ2P4n1YcCqeiepixbtgyBQADDhw9XWpeLr0LbTz75JKZOnYpFixahtLQU8+fPR1lZGdavX4/CwkJyTn5+PtavXx/72exavvPOO3HvvfdiyZIl6Nq1K2bMmIGysjJ88cUXFtHpdVR7UlLYOZmHUwHOrTAPRrxZFa5si3N2OHhFQNSaquHvlgz75BxLOL1lhsTtQH2vTHcOpt6CH2t1sc4KcyExhsu3NF0n26joHgD45ptvcO211+LUU09Nec1U8ZWQvPvuuzFhwgSMGzcOALBo0SIsX74cixcvxrRp08g5gUAAxcXF5L8ZhoH58+fjpptuwnnnnQcAeOyxx1BUVIQXXngBI0eOTGl/kWAgJrZUq6Xdwuw5tNWnUmMFOOf+2LJPXEu7uGTa4jRPd7q/pZcx3ws7+ZY8qHm8+2req25Rp2yfeT65Z3IpVdCcVG2+F/TbpHqFudk+9fvwE9Rrh/saazrO9fuQpmIbFd0TDocxatQozJo1C//4xz+wc+fOlNdNBd8Iybq6OqxZswbTp0+PXQsGgxg8eDAqKysTztu9ezeOOOIIRCIR/PznP8ecOXNwzDHHAAA2bdqEqqoqDB48ODa+bdu2KC0tRWVlZcpCsileEYg6vZRczIJT/xnjpqpqwsWgU1w6Xbijaottj9HfEqB/T+Y2RNzzqzkFOOyG4UxBSFY+px3Ck096G/Wt6LQ3k8QFb6ajcBsbaGzOTpO86MfLBT88Ua0HPwvJmpqauMu5ubnIzc21DFfVPbfccgsKCwtx6aWX4h//+Edqe1TAN0Jy+/btCIfDKCoqirteVFSEL7/8kpxz1FFHYfHixejduzd27dqFefPmYeDAgfj8889x2GGHoaqqKmbDbLPx3yhqa2tRW1sb
+9n8osgE7LQSMgsSOxXgyu2FmJ/UHPte9VpS9nR7Ys3jMtGTqQop4Jii1yo4vXFflUVpS/BkAjzBqTl0bj2CkbqJzqoqyhfgtpBze307oe2SkpK46+Xl5Zg5c6ZlvIrueeedd/Dwww9j7dq1Ke3NDr4RkioMGDAAAwYMiP08cOBAHH300fjrX/+K2bNnK9udO3cuZs2apWOLgiAIgiC0ILZs2YL8/PzYz5Q3UoUff/wRl1xyCR588EF06NBBi00OvhGSHTp0QCgUQnV1ddz16urqhDmQZrKzs3HCCSdgw4YNABCbV11djU6dOsXZ7NOnT0I706dPx9SpU2M/19TURL9hBJsPG1Ph1kyEE9p2xUtJuApU+2B61Uup8xQeam5LKeahPHPmnEu94WNevqUrYWydKDZZ90imEA/dYQcWLeOUn6a4/nFqI7Sdn58fJyQTkaru2bhxI7755hsMGzYsdi2yv2A1KysL69evR/fu3VPbMwPfCMmcnBz07dsXK1eujJWyRyIRrFy5EpMmTWLZCIfD+PTTTzF06FAAQNeuXVFcXIyVK1fGhGNNTQ3ef/99XHHFFQntJMpnSAY3N1Gn4NSZD6n7fG8OqjmeOpugs0P61D4S7C8ZXrHFCZ37XVxSZ2Z7Id8yQvQ3ofeldq91Pm9bAlejEHa8KlwVGzmYesmcU37cDm0jEEw92TmQ2vhUdU/Pnj3x6aefxl276aab8OOPP+Kee+6xhNR14RshCQBTp07FmDFj0K9fP/Tv3x/z58/Hnj17YtVMo0ePxqGHHoq5c+cCiCacnnTSSejRowd27tyJu+66C99++y3Gjx8PIFrRfc011+DWW2/FkUceGWv/07lzZ6W+S5FQoFlhxS0mcboYJt1wWwmpnjvObS+kKi7t2PdEsY1GW06fwmMHszBiF/O40QbHtFfypB72vpKLy5biyfQ9LrQvspL+HEwV3PZIBkIBsnVbsjmpkoruycvLw7HHHhs3v6CgAAAs13XiKyF54YUX4vvvv8fNN9+Mqqoq9OnTBytWrIglom7evBnBJn9l//vf/zBhwgRUVVWhXbt26Nu3L9577z306tUrNua6667Dnj17MHHiROzcuROnnHIKVqxY4UgPSVJ8uP61yh5aj2C0IS517Su6pqkqnHl0I8e+nWpyCi8IVZ2FO+o9Hb0rjNwQcWZvpk5PZtR+ctHOtqWxwCcTsVRCt+DQuad8LMFA6htSeAKp6h43CBiG2209/U9NTQ3atm2L/xvxELKzWyccZ+eEGj8JTtVwt2rLJGoe16tL7ZWzD6oqnFqTcy+4+yfb85js27LFsM+1RWEex21BRL35Kp8LbkOUWtoXMYUMxz53r9yelyz75L5Y5pXtc2xx7ZO2GPea+/ZEvRQ5c6kxdtbUap/x+3V+/7zfZbJ9/LR7L6YPnIhdu3ax8g110fh5v/3eEchvlZ3a3H316HD1s2nfs9P47ohEQRAEQRAEwRv4KrTtdZKdbMMN3ZK2TV/PvOyh5JySQ86zcSyjGW6RDquPJPl701e4Q6EzBO6F3Eo7trxazONKyNrxNbm23L//qnjluEhXMD15ykPp/EeL9bVDeSmT7cPtj8B0HZHoB0RIOoSdfD+OuGSLM7f/2myi3PCcUaQDpF9c2jka0qviknpHJ/OBGaa8Ii45gs0Lld264QvV9DZU98rpPZ6tClckTHQGIPflwRxM1z/aQsHoI9U5GYgISY0YgUBMmFAf+lqLSRQ9mQntOfhX6UbbIJ3ikntvHD8XXLGCnYvTOfxePYXHC+1/7BQZqWKuHAf0njvu9NGQnqEleDM92L6IPuEnjYSg0EfSkZ24jghJjTRt/0P9PdkRl7rmJbTnp9C56lneiuKSW7jDaS+k/Vxw0968XBWezHYq9pWbp5MTPeAa8jnqnkwg3d5MOYfcQdJ4hKTr7X8CCqHtQGa+14iQdAhuaM+r4tILnkwuOnMr2Wsqn0XOOxdcVRC6ERIn0wNUX3fENVVx
GSJ+R2SlOGnLYXGjsX2Oqn3v5AUm92a64cl0PJzugUbsrqAp9OH6R4+Nk20yjUwMNAiCIAiCIAhpQDySGklata1YgAAwIwaaK8U5OJ3r6DSce2anSMeSMsA9XcfnXkoKLxQGuVHM4/tTZQjSHi720VnkXjku0lcovIm4HdpOV0NyPyBCUiORYLCJALD+ZbDFJRWSY4gbN4p5vIqdcDcnB1NVXNo6upEhLj1TFe6jwiA3TubRidNretd+5rUqYhV+tVSxCcT9EZL5zmkkXUck+gERkg5BfejrFJeUaPFqvqVX0CmqdVaAqx7dGN1H/OvM01XhaS4MonC6mIe2lZkfHv4heasiL3gyteP3/SfB9RzJYDD15N2MbFsgQlIrRjBZ+x994pLjtUy8DxGX6YJTAW6nvZDlXHCfVYVb1iOuecW+zvZFXg2T+8m+VzyZ6Q/z28DLe0sRogVmepHQdgwRkhpJ1v6HnMPM0LKEWxVD4gBPXLL7VPpcXDrtpTTDDYmrrsnxWia0r9iIXTl3kxnmVxV/5GuT8gir2meMSWRLteelYB8vtyryBBobsTuJ6x8z4pGMIUJSI02LbSi4Qo8rLjmoiktbxzm2gBxMCp1tiFTbC9k5zlHnKT9eFZdc0p2D6dVjIL2Cd0LPnDXVPJmAx72ZZlzeq4S2vUNmPitBEARBEATBccQjqZMmOZLssJeil5IOYao3EddZFc7ByyFx1ZNzVG053V5Ip5fSTj6nzsIdck3Tz6Rf346nnWGfC2uvbFstM9/SuzgbErdzpKfj+Zxp/P267pEMKORIysk2QjIigSZ9JKl/t2HbGgqzrqBTXNop3LHuy9/5lm6cnON0eyFVcamzWIgb/ubmOlrmEdecDn9TqOZbulHMIzSPVwp8fI0msen6R4OEtmOIkNRI02IbCq64ZFVtk7b0iUs7hTuc9VpqvqUdUaqzvZCquPRKbqVOcan1iEfiXivnizLXlGKe5vFVVbUU+LBx/W1fhGQMEZIaSVZsQ6EqLrmn5Oj2XHLWNONGmJzCq4LTaY+nbnFpxk5onrOeqrhke0qJazrFJWdNp5uz+63npb/En9PoK/DxPU1eB66HtqX9TwwRkhpp2kdSZxjbjj2uuOSs6NXm6VycDke7AadPpU647YVUvacUquJS1WsJ8P7e7DR6t6QkMO2zT69i2KLwSpjcnPNHhZSFprQ8T6brr4lgQMEjmZm/k8z0swqCIAiCIAiOIx5JjTQNbTtdbEOub8OWtY+kd0/h4ezB77gRZtbap5LhpdR9yg+nKtyWl1LjKT+cMLzTvSwNGz07vRICF5ojs0PiVPV6WpEcyRgiJDXSkB0AsqMvlKx6QnQRc3SeO2/HvnmuziMedZ/CY92XN6u9/Ybj7YUY54L7qipcYyN2LjrzOaWYR3A6JE61JdIVkvZGaFtyJAERklppmiPZkE0UuRAfTtQbvkF+WJu8LUyhpCoudR/xqIrOYh7SfgsQnLaOZXSwvZBqbmUiVHt2ekFc6uxvSWEnnzNAvW8p7CE6TvIt3cRctER79YjfN/ECcvo+UgVWTdd0/fcoHskYIiQ1kt4+kvps2bHHKebhVok7XcxD0VJD506j81xwtvhT9J46XhWuMwzPfV2b7oVuT6lq83eqZ4ufvJnJxE3m4r0+mK7fdxGSMURIaiSSHUS40RPpcGjbDaHKFnoWW+otiHTmYFKohs4p/CQ4VT1/3NxKN9DpPXW8KlxjviVrPRv9LTl/l043VCcFnNs5ci7hbTFr3ptz+zJcfs6BQACBFE+qSXW8X8hMeSwIgiAIgiA4jngkNRIJBhBo/PZO5EhGmDmSqr5Fr3o81XtZ0quqFvNQ2AmLm8nEPpUcvFABDjh/yo/TVeEWW8S1dDdKB7xRKa4z/O0Vj57fvay8/evNt/RUY/qAQmg7oPbXtHDhQtx1112oqqrC8ccfjwULFqB///7k2AcffBCPPfYYPvvsMwBA3759MWfOnITjdeA7j+TC
hQvRpUsX5OXlobS0FKtXr0449sEHH8Spp56Kdu3aoV27dhg8eLBl/NixY2Mu6sbHkCFD1DbXWMW1vw2Q+WGErA9qXDgraH1kxz8aiAe5JvGgxjUe7xh7MG2x7Jtth+h5FJFgkHgkeA5J1qQeFJznmIkEIobl4STBiGF50OMilgcFZ+/BsGF5qO6Nul/s56S4B+7viJqrsq9gmLcmZ71gxEAgbH1w9g/qQRCJBKyPsOlBjeE+zLYI4WfLvuL+2feCMS8cDlgeOu9XJALmQ23/aaMxRzLVR4o8+eSTmDp1KsrLy/HRRx/h+OOPR1lZGb777jty/KpVq3DRRRfhrbfeQmVlJUpKSnD22Wdj69atdp9xQnzlkWy8oYsWLUJpaSnmz5+PsrIyrF+/HoWFhZbxjTd04MCByMvLwx133IGzzz4bn3/+OQ499NDYuCFDhuCRRx6J/Zybm6u0v1B2BFnZ0Q85+g2GyBUkvrerVmHqrBQ3V4knmkfaN/1sp80IRbqLeSh05mBmIqq5lF5pL6TqUdVZuMP2nhL2Ld5ZG8VnnB6Xdjylqu93OivFnWxTkwpe8ZaqYqkK53oayftPjfRS1XZ62v/cfffdmDBhAsaNGwcAWLRoEZYvX47Fixdj2rRplvFPPPFE3M8PPfQQnn32WaxcuRKjR49OeX0OvhKSTt3Q3NxcFBcX295fMGik7HqPZBFvkZQNonjHMo2yn9JuvG+LsmcnTK68B8XQeUsVlulAZ3shnT01dRbuqIpLW0KPsTfdYXhLMRLjCys1L3rRvFcX2ta4ELJWFcfeFrMH9pYgMJE+0lC1XVdXhzVr1mD69OlNTAQxePBgVFZWsmzs3bsX9fX1aN++fUprp4JvhKSTN3TVqlUoLCxEu3btcOaZZ+LWW2/FIYcckvIes7KSeCSJN7mGeiKXknijC5u8jWSvSeLNNqtBzeNJ4Ya4ZFVtE/O4DdU53kw3PJkUfhehnKpqch5TsKmiKi69km/J8p4S9nUKPe7JPBSsU7sU+1uSc5ktiDwjCBW9fF4gZMPTm+z+u55PasMjWVNTE3c5NzeXjIRu374d4XAYRUVFcdeLiorw5Zdfspa8/vrr0blzZwwePDi1vaaAb4SkUzd0yJAh+L//+z907doVGzduxA033IBzzjkHlZWVCIVCpJ3a2lrU1tbGfm58UWRlGcjKir59NjC9ZI3CsynUH5pZcHKT4ql9+KnAJ6xYDGOvobp1lC5UCy8AeyLUjJ3G315AZ59KCjeKeVTnKXtPiTWVT8mxIXp1tj7ivC/a6W+pKi51evns2OI1JNe7plO4vX5USKbqkYzuuaSkJO5yeXk5Zs6cqWljB7j99tuxbNkyrFq1Cnl5edrtN+IbIWmXRDd05MiRsf8+7rjj0Lt3b3Tv3h2rVq3CoEGDSFtz587FrFmzHN+zIAiCIAiZxZYtW5Cfnx/7OVFdRocOHRAKhVBdXR13vbq6Omk63rx583D77bfjjTfeQO/eve1vuhl8IyTTdUO7deuGDh06YMOGDQmF5PTp0zF16tTYzzU1NSgpKYnLkWz0TDaFG9qmMHsuydB5A/ENmrBF+tdMuZoG+Q2UCJ0zGq+zk+KZ4WKnT/kxe0hUQ+LRcYrhXBueSw5e8D660dzcjpeSZV9j6Jm0T1xj5TBqLNxJNFcVjpfSzsk/Tp/C42cvn53wver+qXlhG2u6ho0cyfz8/DghmYicnBz07dsXK1euxPDhwwEAkUgEK1euxKRJkxLOu/POO3HbbbfhtddeQ79+/VLbowK+EZLpuqH/+c9/8MMPP6BTp04JxyTKZ8hqUrWtKhABXmibxpuV4tzzxFVzAHXmWyayZ7WvLi5V90XOdVhw+glOVTU9T613pep55YB64Q5pnxPmd6Eq3I0jHr1wCo8dlMWZR6vO7eRDJhOX3ghtO1+1PXXqVIwZMwb9+vVD//79MX/+fOzZsydWdDx69GgceuihmDt3LgDgjjvuwM0334ylS5ei
S5cuqKqqAgC0adMGbdq0SXl9Dr4RkoD+G7p7927MmjULI0aMQHFxMTZu3IjrrrsOPXr0QFlZWcr7CzT1SGoViDyPJDcvk1UpzqgSB9wpwEm7LeYRjJx8SzutiihYLXV8JDZ1eintiTN97YVI+xo9hqR98zyHq8I5e0i4D04OpsYjHtliWfVeUO8Nbog6n1dtJxOXXC+mY6TprO0LL7wQ33//PW6++WZUVVWhT58+WLFiRaxeZPPmzQg2sXv//fejrq4Ov/71r+PsOJWHCfhMSOq+oaFQCJ988gmWLFmCnTt3onPnzjj77LMxe/ZspV6S2VkRZO8XfPV2PJKUSGxIbi+L8jQqVoqbq8QB9Upx3aFnneKSgvUhyRSXOsPk5D4UP4T9JC51olNc2mkvZLHt4VN40t1Tk7JPoSouVUPiAE9cku8fNsSlWRDaEXWqtrzoKY0YLgvJNJ5sM2nSpISR11WrVsX9/M033yitYQdfCUlA7w1t1aoVXnvtNU07Q1yOZDYhEKlvUKSXkni7NYtE7jwKjreUFptWOJXiZGhbsVcmtabfcjDNcPtgqrYqItdU9fK1UAFKofM4R9K+RnGpszKa2ofOnprUXDuvayfzLQGruGR7cBXFpVfD2Nx92Nm/t3Ik0xPa9gP6sswFQRAEQRCEFoXvPJJeJivJEYlc3c6p+KY8htxKcU6uJjfkzinmYXv0iHA6VSnu1WIeli3FkDjAC4vrDIlTuFEE5HR1t50TXqy23C/cie5DLfSs3EdSYxgeUA9Hc7BTbMO1x7HvBS+l7qbreo9IbH6u4XqxTXpyJP2ACEmNxL+uiNxB4q0jyHzjMOdc+qqYR2MYG/BIAQ5znE77HHFJCRkKnR/MpH2NRUBOtwRymnQX7gB6Q89eEJe2hJ7iKTw6m7OzT+FJat2az54Ir4hLVZLt3/Wq7UAw9ZxHxRxJryNCUiNZWZGYV5B6kRPOO9LjRok/c84llW9JFeTwPYsmjycjTxPgFfNwjnwE/F/Mw7HvRlsicl8aWxWpotu76XwPSrX2QrQttcIdinQf50itqfN1QeHKKTwOr+n0EY868cKJPirrO4oIyRgiJDUSanrWNiUktYo/ngSiBS0jtG2jobrFu8m8F9yG6uku5qGOaYwQ9jkfKLqFq/kDUGeYnMJpTya5pkc9km6ExCncOM6RE5r30xGPXFTD2Hbsc8L8TjdP9wreKraR0HYjIiQ10tQjyWnXAyTw8rG8mephck7/LW6YnBNi54fJ/ZOD6XSYnBZ/auHiTBSXfkL9rG194pLjtQT44tIyj7imU1x6tb8lYN2b7rZHnJN5yH2R19wXjrqqzt0PbQcUPJLu338nyEx5LAiCIAiCIDiOeCQ10rRqm4IbZuaFgnneNX71ePJwNEW68y0BXjEPlW9JhbalmOcAXC+lxZYH8i39htPHOZphF9EoVsjrPIWHi9M9Nb1aKe6VU3jc7mfJrCt0DsmRjCFCUiOhQPQBAGBWVXOEGMD7oOEW+HDC4mwhmeZ8S2pv3ObpZGg7zcU8tmwxxIDTxTwSEncGp49zpLAjLi3r+aiYx06jdFX7jhfzsHam/n7khZC4GfdD2yIkGxEhqZHcUPQBAHXEX2gkYH1DCDOr2jhJ0fx5yb2Z3CTsdOdbUnPZlekeKOax47XknMyTicU85L5aqODkwqkKp7DThoiDV4p5VPaQyj5U98URl3Ze+07nW6oW86hUd7teKCTFNjFESGokO0g6vWJQ72eUHuEKTss8RU8mNVfVk8ndFzmO6Z3lzPNqMQ95JKMLDdVVi3ns2VcLk1NI6Dw1vFIVTu9NLRytKi7Z7aQc9p7aEZeWeVz7jHPBlcPkSK/n0hNnbYtHEoAISa2E0CS0TbxeiJQSUlyyBKeN0DkFp9Kam5fJ2Rc5jhvuNoXK3cjBdPpccI4tO/ZUPRFuhMkpMs27aacnpVdx+kQip3Hae0qhWsGu8wxz3eLSOkbPa0BC294hM5+VIAiCIAiC
4DjikdRIdij6AIAQ8dWMG9rmeC6pL/b1iiFxgAptq3ky91+N/0ljvmXiNU1jHC7m4Z7MY/Zcas9h1GnL6WIexXxLCtUcTNqWYkhfY0Nyv6F6VjgXJ88KV22KTtmi7OnMreRix0vJss8cp7N5ejIMt3MkxSMZQ4SkRrKDQM7+1wn1GidFo6LgpMQmBblmjnXRetOGVUPi+1eNH6Mx3zLxmqYxNvbPKeZRrRTnVIkntMVAt1D1gn0Kp3MwOWSiQHQare2FNOYw6hSXugt3OOFoCo641C1wdRbz0OMO7Nf10DYUhGSGBoFFSGokO2gge79yyyb+vZ46ZUZRcDLzxdk5mED8RcqTSfUNowSVuQKPGqOabxkdp+aRVLbFFaWMSnFOlTjgnUpxDqrFPKqeTEDdm6nTkykcwE61t9O5lJ44ztFpj6rG/pYUOot5uPYp4uy5ndssVdsxREhqJKepR5J6jVPijKg8I0PbpmGk2CTee6k2RJR981xu+yLqrYIjvPiilPIGJrfvdJicO84sEqkqcTvFPBaPJ+PIR4D2glJnijtZKc4+Ik5RXLoRJvcT/Ar59H/4KR+RyPhb1ekdBPTuVatHVbGC3ekwOYWK4CRTidJIIBBEIBBKeU4mIkJSI1lN2v/wcx+JvEYytB3/x80RmwD9Bag+TNmP/5krSoOEuKw3GVP1ZCYaZ34LC4asT4iaFyL2UU/aj8demD/5GDdyMLleUKerzs1kYg5mS8XpnpQUXhCXOs8wB3h/c27IEzf21XTNADcs5xSSIxlDhKRGmvaR5Lb6URWcHLGZyBZZqGPSYnZC52Z0ejIBqzClxaYN++bQfAPvjz/dYXIAvs7B5BT3JLLPEZc6Q+JA+nMwMxFVcamzT6Xu3FadxTyqOB0mV31O2r2/TcZxU6EcQ4RkjMx8VoIgCIIgCILjiEdSI8lyJFVb/VBzdYbJKWzlYDLyLTkhcYAfFrfMo/IVCc8iNyxutU/si7Kv6t1k5FsCRM4ls90Tt5jH/K0/i8jdVPVScnMfvVwpbkZnQ3XBXZS9py5UhVOkO98S4OWLciMAyf6SgsTnXVoRj2QMEZIaCQUOiChKdCmfYgOrcNQZJo+iloNJYd4b9b5H1ZJwe2OaZQRb1LFFqZp9CtZRfoqilBpHhcTDxFsylW/JCYs3ELmbqsc+qhbyAP4q5qEwC04vC0s7xytySHe4WzWPEvCGuLTTv1RrBbvO4yiZgtPpHNuUkKrtGCIkNZIVaP6sbTseSYs40+jdBICQSbCRXkQqB5PhpaR6ZVKwe2Oa9kp5Mim4fcfM3kxaNFJ9MJMX8zgtSpU9mYDljHFqrpfbEvmpmMeyns+Ke8x7013ZrdrwXKe4pFAVZ9w1ueLSMs9PzdNteGeb7tfOFzktiEcyhghJjeSGgNz9aog6UJ7qLUn9LXCEpOo8gP6M53kReaFzs3dT1ZNpZxyn6TqgLuzo0Hly+1zRyK1g11kYRAlO87cOTuU4APIblZNhcgqvFPNQSJg8NXT3qTSjWrgD8MQZhR1voGWew83TKbxQzOMqIiRjiJDUSG7IiAlJc4NvgBZ/lOCk3vfM47hCkhtCNs9lh57JcWbvpponkztOZ6siwBoC54eek+dlUjmZHE9mKvtgzWPkgQLWUDmnchwA+cIwj+KGyXVWnVM4noOpKDh1VpMLzuFkDqMd+zqbp1O4nYMZJNrYpRURkjEy81kJgiAIgiBkAAsXLkSXLl2Ql5eH0tJSrF69utnxTz/9NHr27Im8vDwcd9xxeOWVVxzdn3gkNdIqK4LWWVFfBserGL1GhYuTj+N6MkMB3jjzNW6xEGec2UOZyFaQ2Ct5EIzJXURF9qgv1ZSDh9y/uaE68dWX8vJR4WJOz0vq9B43vI8UZq8AN9+StGX6pXM9jewjJC2eG8LbrKliFPCG1xLwQK6YS+gsxtAZRvWKl4/cm4MnBrU40lRs8+STT2Lq1KlYtGgRSktLMX/+fJSVlWH9+vUoLCy0
jH/vvfdw0UUXYe7cufjlL3+JpUuXYvjw4fjoo49w7LHHprw+h4BhuF1DnxoLFy7EXXfdhaqqKhx//PFYsGAB+vfvn3D8008/jRkzZuCbb77BkUceiTvuuANDhw6N/bthGCgvL8eDDz6InTt34uSTT8b999+PI488kr2nmpoatG3bFhX//gsOOrgVgFSEpHWcemibZ5/ah/kccDuhc/M41XnccVTonNo/Fe6mTvnhhPnJIySp/ZtCw5SoowQiPY54nmb7hC0qTM5d0yyOOWMSjmuI/5n6wKJOq6AEJ3lkm2kueVIPtSZzHCfXjmufM48cp/E0j5YiQJ2u9FUVnKrijPt8VO1znw/Hvu69Nh3XULsXlXf/Grt27UJ+fj5rvg4aP+93/fAs8vMPSnHuHrQ9ZERKey4tLcWJJ56I++67DwAQiURQUlKCyZMnY9q0aZbxF154Ifbs2YOXX345du2kk05Cnz59sGjRopT2y8VXHkknlPmdd96Je++9F0uWLEHXrl0xY8YMlJWV4YsvvkBeXl5K+8sJRZAbavRIUn8YdvImk3skae8mz36uRfypi1JrYRAlXK226L1ar5lFHNXOSKdQNYtsag8A7RmtMz2B+hyrcq2LWK9R++KIUK6oo0QpNc4sQikByl8zeasiUgiHrW9TpPfXJDgpUWpLXDosVJOtx52XaE0OOqtzdYpeO6S7QMOO905V9LKFGONecG1x9qoiEJsd12T/kdSOudZPGnIk6+rqsGbNGkyfPj12LRgMYvDgwaisrCTnVFZWYurUqXHXysrK8MILL6S21xRIWUi+9dZbOOOMM8h/++tf/4rLLrvM9qYScffdd2PChAkYN24cAGDRokVYvnw5Fi9eTCrze+65B0OGDMEf//hHAMDs2bNRUVGB++67D4sWLYJhGJg/fz5uuukmnHfeeQCAxx57DEVFRXjhhRcwcuTIlPaXEzCQEwunMRPlmbYpMWYdw7NFi9zktjh7iNpPPo+/1+T70O2dNa9JOMTYFfLmNSkByhe4lOAMm8ZQtnjpDZR31rxfrqeX85zY3k1l7ynxGiDsUx5bzj7IYyyJMH+YsFVbn/wLAHmNXJOZx2GC6xHmzLUjQJVFr0eEqio6Ba7TAlR1rqHRuwkgLmcpnF3Pm+MUNoRkTU1N3OXc3Fzk5uZahm/fvh3hcBhFRUVx14uKivDll1+SS1RVVZHjq6qqUttrCqQsJIcMGYKrr74ac+bMQXZ2tKHN9u3bMW7cOLzzzjuOCUknlPmmTZtQVVWFwYMHx/69bdu2KC0tRWVlZcpC8uCcMA4mPE7NQeUPegE3UmICGmu/AkS+pd/2YbHlgdo4w0ZmICeLhrJPzaPGmfWIzi8TANBgEmyU/QbmmmZxT4l9riffvC9qb/QYni1OGgf3ywT3i4j1d8mzz/mds78EMl/qHD3L1byZlm1g53MkmQat3b0Xa9XN28eGkCwpKYm7XF5ejpkzZ2raWPpR8kiOHj0aFRUVWLp0KTZt2oRLL70URx11FNauXevAFqM4ocwb/z9V9V5bW4va2trYz+ZvF4IgCIIgCBRbtmyJy5GkvJEA0KFDB4RCIVRXV8ddr66uRnFxMTmnuLg4pfE6SFlIDhw4EGvXrsXll1+On//854hEIpg9ezauu+46rd4XLzN37lzMmjXLcr1NdmscnN0aABBgflMJwHrPKM+T+d6SYyhbxD6CAWtyidmePVuBpGPIr/sG5ZposF4LNyQf00BcI8YZYSI8Yh7H2UPCfdSZfiY81tReyXGMa4QtI8y1xdgH9XujbFHjzCeAULF0bpUXgaVZOrvDvQsNvUOM9wfqZA+ui8c8l2o8Tbl8qKpSas2srORjKFtZxHsBuabpGnNeIMQYR+01W3FflD3zvUm0JmWL87lBjdFpy864NPY0rflxH+5O22oEAZjP3uDNAZCfn88qtsnJyUHfvn2xcuVKDB8+HEC02GblypWYNGkSOWfAgAFYuXIlrrnmmti1
iooKDBgwIMXN8lEqtvnqq6/w4Ycf4rDDDsN///tfrF+/Hnv37sVBB6VWwZQKTijzxv+vrq5Gp06d4sb06dMn4V6mT58eFzKvqalBSUkJ2uYUIj8neg9I8UR9rlHiQ1U8keKm1nJJWTyZRVHCNeOvGXXEepT4oMYR1yzCiDmPFEqMvRk/Wfdq/EQINiLuaNQ2JB1DxfYitcReKfuma5Q4M+qJIixGBTg1zmDm7VH2zd8TuLa43znM9sjvJcSapC0iREqNS7aH6JrJ51HnnNsh3Yfd6NYPQdXWOyztxCxY4mosRnoSrf307UP1OXF/b5znSNnnorL/H+vdzZE0DIOVrmOekypTp07FmDFj0K9fP/Tv3x/z58/Hnj17YrUio0ePxqGHHoq5c+cCAH7/+9/jtNNOw5/+9Cece+65WLZsGT788EM88MADKa/NJWUhefvtt6O8vBwTJ07EXXfdhQ0bNuCSSy5B79698fjjjzumep1Q5l27dkVxcTFWrlwZE441NTV4//33ccUVVyTcS6LE2MAPmxGoi7b/MbgeK2ocJdgsniemR6mO561jiTP2mqa5xDxSdFGCjfJamcUTV9SR1/TZp059CYeTixtS1LGFXvw7t0GUMnLFGbWmeb+kQGSKLrPIaiAELrWvMNXXjnSCGknHkJ0TKPvU/WGIPcq5yZlnR/hR+2fNc/tkEI8TdLgqmDoJizfPxpqKAp3jQLdjn3svmo7b40YkoQkGIinnjKvkmF944YX4/vvvcfPNN6Oqqgp9+vTBihUrYil5mzdvRrDJi2LgwIFYunQpbrrpJtxwww048sgj8cILLzjWQxJQ6CPZqVMnLF68GOecc07sWn19PW644Qbce++9cbmDunnyyScxZswY/PWvf40p86eeegpffvklioqKLMr8vffew2mnnYbbb789psznzJkT1/7njjvuwO233x7X/ueTTz5Jqf1PY1+pnW9ORX6b/QKT+GTQGmK0ESJliaw6xXmApfyXPJ+Z8MLR3jTGmtx5XBFnEX88jxsl2FhCTKNnjhSgTC8cvX9ThTxTKHEEIUcMpjTOfC+YXj6nxR9H6HFFnapodPkzt1lUvbGqosVpuKJLFVUBCqiLYy+K3j3hegz7usK1PpI7dj6v1EeyfcH5ad+z06Tskfz000/RoUOHuGvZ2dm466678Mtf/lLbxiicUObXXXcd9uzZg4kTJ2Lnzp045ZRTsGLFipR7SAKA8f0PMPbkRH9g5pKRQow65cI8jvhkIG1RJyVwhBdnD8x92NkXtQ+zWKLFjfUdjBt2NNunPG6c0C21pp3QLUckcsO0tC1iTfNJGA57DLneQQrzXlUFYqI1dQpCFduA84JQd4jdSXTuVacopX5HTotLLubXJ1dYUq9PjrgkTx1j3gvq9+ulLw+GEYHByXcxzclEfHeyjRdp/IbywwMXIL/VfiFJefSYhQS08FIVesQ4johj9u7giEt6PWpfXHHDyNvjhm4ZYVmuF5H0eDJC22xbjP1z7at6FlVFIzXOTkiZI6i4nlJyboZ5Ef0kEN3AaYGiU0iKRzKK2x7J7f97Rskj2aFd+k/jcRqPfE8SBEEQBEEQ/Iavjkj0OsaP9TD2pykazKNDuF5K8zV6HpUXyDs2xbJfG/bN3kbaI8bzDnJCsFyPHvdUEE7eIblXh3MYOd5T6l7oLGrh5jBSWELbGr2P1Fyd3seofX22zNjxPoq3MTV0ex+94m202HLY+6jqbbRz/5ve65DLL/to1Xaqoe3M/FsVIamRyL76A/qLeXyCsiDkhp65x0toDUcnzzHkhmBZbWSY9lVD1FzRyKl6tiMaOaJatfAlOtc6TjWvkYKzV1XRSNmn9+C+aATUhaOIxuZpCSFqO9Xk6Q5Rc38fXskhTQVj//9SnZOJiJDUiPFTGEZjtgDXO6gqCNmHVavtg+05U/SS0d47Nc8iW5y5kNdo3hunsjuR/XQXw1DjWkoxjOQ1Cl4QjYB38xrdbiUUTLkb
uF7S1f7HD4iQ1IhR2wCj8cXN/CDS6TEkIb2gxD4YbV5UPYvsCmemfU6lss5wtKqnlFrTT8Uw1LiWXAzjJy+i3xuS+xkvi0aveha59pvun9nD3DGkavsAIiQ1YtQbMILNvFDs5DVa5qkJRIAnEm3ZUmxJoxqOZoeGXQhHcyrAdYajdfZhpOCKRtU+jG6Enr0gENMt/HRjp82LKjrD2Ha8j+kOUftJNNr5HXn9y4mEtg/g8V+VIAiCIAiC4FXEI6mTiHHA66jqaQRY3kbKO0huSWP1r2peo2q1NNc+3TBczRa1N9VTYKK2zPO8UQzjhRC1zgpwCq9UQvvd25hp+Mn7yF3TL3mNQAp7TfK8g8T7fjqR0PYBREjqJByJfXqxC2sI0p3DSNnjii7OmnZEI+ckGJ3FMNSamVgM47RobKk5jC1VNHo5jK2ct9cCRKPzZ2hzxzHD9U3urbsyUkLbTREhqREjYhwQkMwcRtoOQzxpFI3UuJZaDMO1n4nFMJLXmGRuCxWJFF4VjukWjQBPONqyrzEXUWdeI2tfCgKRa891j6RUbccQIamTiBETkFxRR5pJczEMZU/n+dJ2imE4IlRnMUz0WnLvL9XkO90Nvam96RSNpH0Pi0bxLPqTlioa/VQMo+pZVBGIzY5rYs/tAo9oQ/IUPZLSkFxIhhE5IExUq6Ub7VivMULbHglHWyqVbfROVPV4qjb0BpwNR+usoI7aU+vN6HSbHa/mMIpATB2d3sd0i0bA+RC1n8PRToaenbAVajI3xHTMOIV4JA/gtqgXBEEQBEEQfIp4JDViRAIxr5qfimEAa9jajn2dxTBkON3HTb7tFMNQmL16ur2P6Q5R+z08Te2f8gyZxzl9tJ8buPGcMs37CPgrr5F1f5j3OmQjBSEdSLHNAURIaiQcDsTCuG7kMHKrknXmNXLa7NjaV5qLYQCgnsh/tM5rGcUwToaovVz4olXQcn5vzPWcFmdeaQLNEk8eEI3RcZw1vSEadQo9VVtcgZjsvrr9WpX2PwcQIakRpRxJRUGlWvjCtc8tfGH1kWS29VEV317uzWhdT100+qk3oxdyGOWs6vThdHsePxXI6D5OUDXX0WnRyJlrRzQmPWvb9art1D2MmfqOJEJSI5GGAMLBRo+k3tAzq8DExppmz6KdPoyq3lPpzZjEvsO9GZ0OPYtI9BZ2PDpuFM1wcKOq2ukCmXS32VEVjQBPOKqIRmqu2x5JKHgk2T0AfYYISY00zZHUKRoBa8Wx7pY6ToajuaKRaqnDEYk6+zBS6AxH+ylfMbqm2v1RRQRi5uB0pbVX8xq9LBqdzmHU6bGlbAWCBvnfbiA5kgcQIamRSCQQC9nqzFekxukUjdQ4TpEL176dfEWdvRl1FrqohqN19mGk7Hk59CwiMTV05kO64X1k2/exaATU8xp1tt7JdNHoRaT9zwHcdg4LgiAIgiAIPkU8khppWmyj0/sI8IphyGMHPZDX6OViGK/mNXqlpY6qB9IL3kfu3l3PtYJ+r5/Tp6aYYYdz3Whs7YGqai94HxPNNaPqfYzaj5/Lncf1PgaC9H+7gZxscwARkhqJRA6INFu9ExkVzTqLYQBnezPWN+jLV6Tm6haNOtvseLUPo58FIuB8+x8/kW7RCOittPaTaPTCaTFeFY3cuSqiMXYtYJD/7QYS2j6ACEmNRBoCiARSK7bxQjEMYPVm6sxr1F0M4+c2O6rV0lH7mZ/X6LRAdMP76IVcRz+JRq49pwtknG7y7YV+jW6IRq4n0W2hmAzxSB5AhKRG4vpI+qgYhtqbn4ph7AixdIejW7JoTLcXsaWKRu4+pKo6tXncfXhVNEbtm23pbYVkFo52RGOyua6HtsUjGUOEpEYMIxATc3ZEI6fZuJ96Mzqdw2gnnOunHEY/tdlxI/ScbuEoorHpHvTZio4z23c2h5Gzh+g1Z++F202+E82zE462jFEQjY0EPdT+J2JEH6nOcZIdO3Zg8uTJeOmllxAMBjFixAjc
c889aNOmTcLx5eXleP3117F582Z07NgRw4cPx+zZs9G2bVv2uh5IMxcEQRAEQRDsMGrUKHz++eeoqKjAyy+/jLfffhsTJ05MOP6///0v/vvf/2LevHn47LPP8Oijj2LFihW49NJLU1pXPJIaiRbb7P9vjRXUgLPFMNQ4L58MI8Uw9u2r4pUil5YattbqmXPh2EGvNP62jFH0PkbtM8ZoPC1G5z3U7X3k5DWqeB+puQFn25smJWwEEE7xmMZUx6fCunXrsGLFCnzwwQfo168fAGDBggUYOnQo5s2bh86dO1vmHHvssXj22WdjP3fv3h233XYbfvvb36KhoQFZWTyJ6BuP5I4dOzBq1Cjk5+ejoKAAl156KXbv3t3s+MmTJ+Ooo45Cq1atcPjhh+Pqq6/Grl274sYFAgHLY9myZUp7jDQEYo/GfMmmj0gkYHlQ4xpD5HEPy7iA5dEoZOMeYcP6IMaFI0bcIxKG9RExrA/WON6+whFYHtz9Wx+8vZK/R9M8al/svZoe9HrWBzlO0T779cu4ry2FYChgeSjbClofnPW4ewgFrQ/rHgKWB/28mQ/CnvWR/Hkneu4qzzHhfST2YX3YuD+ce0HMowgFA5aH6j3kPO9A0CAe4D0ChuVhhpoXDBqsB71u/F7dpDG0nerDKSorK1FQUBATkQAwePBgBINBvP/++2w7u3btQn5+PltEAj7ySI4aNQrbtm1DRUUF6uvrMW7cOEycOBFLly4lxzd12fbq1QvffvstLr/8cvz3v//FM888Ezf2kUcewZAhQ2I/FxQUKO3RMBT6SLpwWkw9eRShvlxBTuGOGxXU4llsYt9HotBp76MbnkYv9Gv0ajEMYH2eLbUYJmpP3/1PZzEMkNyr2OyaHj/ZJmIEEEnRw9g4vqamJu56bm4ucnNzbe2nqqoKhYWFcdeysrLQvn17VFVVsWxs374ds2fPbjYcTuELIem0y7agoADFxcW292ns9zJG/5v4dxsFMuZzqHUWw1DjVNvuRG0lr6r2QgU1kHmi0U8C0Q1ENNq3FR3HGOPRfo1SDNPkZ2aLHUr8USKRM48e57+q7YgBpPrW3fixVFJSEne9vLwcM2fOJOdMmzYNd9xxR7N2161bl9pGCGpqanDuueeiV69eCfeSCF8IyWQu2/PPP59lJ5HL9qqrrsL48ePRrVs3XH755Rg3bhwCzSRg1NbWora2NvZz47eLpO1/iLxJL+QwUuNUz6CO7i35PFXRGLWnNk8135Lcg8NtfOys6We8cMpMInSKJzOZeDKM5DAmWdMDvRnt5CuS9s1C1Ua+ZbJ74fZ7hR2P5JYtW5Cfnx+73pw38g9/+APGjh3brN1u3bqhuLgY3333Xdz1hoYG7NixI6mj7Mcff8SQIUNw8MEH4/nnn0d2dnaSZxKPL4Skky7bW265BWeeeSZat26N119/HVdeeSV2796Nq6++OqGtuXPnYtasWak/EUEQBEEQWjT5+flxQrI5OnbsiI4dOyYdN2DAAOzcuRNr1qxB3759AQBvvvkmIpEISktLE86rqalBWVkZcnNz8eKLLyIvL4/3JJrgqpD0gst2xowZsf8+4YQTsGfPHtx1113NCsnp06dj6tSpcfZLSkoQCTc92Uaf9xFI/2kxqmdQU3N1VlBTc3U3+U53iLoleBrdwG6BjJP4+YjB6Jrx9nR6H6l92LkXfmryzfHg+Smvke/xJGwlCbu7ffKN1/pIHn300RgyZAgmTJiARYsWob6+HpMmTcLIkSNj6X9bt27FoEGD8Nhjj6F///6oqanB2Wefjb179+Lxxx9HTU1NLMLasWNHhEI8t7yrQtKLLtvS0lLMnj0btbW1Cd3NiRJjG6unAb2iMTrXPC/9p8WoCjGdohFIf16jnfC0iMT0oSoc/dSKx+kCGacbf+s8h9oLojFqL/PzGlX3yi4CUnhObudIeq39DwA88cQTmDRpEgYNGoTGhuT33ntv7N/r6+uxfv167N27FwDw0Ucf
xSq6e/ToEWdr06ZN6NKlC2tdV4WkF122a9euRbt27ZQqqOL6SGoUjdG5ih5JxX6NOk+LcaMYxumjAp0WiNz74zR2eg0mt+2YaUfQ2tvQA1XVXhaNTuY1UqKxpeQ1sk6esbFX872wI3rpcQb5324Q2f9IdY6TtG/fPmEnGwDo0qVL3Hnfp59+upbzv32RI+mUy/all15CdXU1TjrpJOTl5aGiogJz5szBtddeq7TPSNhAZP8fjleqqlVb77CLSRT3SqGz+Tg5zgOhZ68IRC7m/TopLFPBae+jn0VjdJzZvj7RSNnzqmgEvFsMoyqyVAtfotfUPJKqnkXdrX685JGMQKHYBt54/9SNL4Qk4IzLNjs7GwsXLsSUKVNgGAZ69OiBu+++GxMmTFDaY33YQD32C0kbXjidVdWqrXd07tWNHEZXzq/2mUjMdJzubUiv6Y28RqdPhkl3v0a/5zDqDEc72VInuh5hS3H/Os/oBuLvhdsRDa/lSLqJb4SkEy7bIUOGxDUiFwRBEARBEPj4Rkj6gUj4QA6E7mIYqy193kdqH14uhkl3iFo8jXpwuqgl3R4Kv3sf7dxXJ4thqLleKYZh5RgqVktH95F8rpO9GQF7+1ev2vZfQ3IvFtu4hQhJjUQiTXIkHS6GIde30fjbz3mNtgprPCASw0zRzuzEoJV050TqPHnGjn1u2NdqnzkuzQUyfhKNgL/zGlWrpROumcaWOtw9cOfaet7J2v/A3fduQyG0raGuxZOIkNSIkcwjqTmvkWVf4znUXhCNqdizztP3V8wVfzqh1nRDXHJw/HxsjafMqBbNeFU0cufqPC1GZzFM1B5nD7x5OvMa3ThfWmdLHeerthXzORX26rZH0s7JNpmGCEmNxHkkfVQMk2gfnH2R4zS2EuLN0/s1zw2R2BLRHbJOd6W1nXk6q6rdOGJQVVRzRaNXw9Hp7sMYXdM0xkbhjnrVtmKxjYN7dbv9T9hI/axtjSfmegoRkhppqDfQ0Iyo0ZnDSOG0aPR7OFoEYvOoiyfmOEVxoGrL6fOr/ZTXqLOCmjs3E0VjuvswAnpb6jjuUdW412TvBeKR9A4iJDUSjgDhZl4nOnMYSftM0ejVHEYRiOnDC6KRvaZH2vM4GaJ2uhhGt2jkFL9w74VXcxh1FpjoDEfb826mN8yvIhCbm5vKvzuNtP85gMuaXhAEQRAEQfAr4pHUSCRiINLYkFzjyTCJxpnRej62h/MaW6oH0p2qbW/a9rP3EXC+qlrnaTEttRjG6YbeXmip43S+pc4zxoH4/QZDkiPpFURIaiQSab5qm5yTgXmNfgpRq+7VK0cFcnB6rzqLZihbTh9F6IUCGd1HDKreCzdEo1eLYZzuw6izpY4XcxgTzbNTyd10TfdPtpEcyUZESGokKiT3eyRdaOjdUkWjF3pB6kbV+2hHNOpsqZNpojFqT1+uo84jBp0WjU4Xw6Q7r9GOaNTZh1FnSx2njzB047jIZGtKjqR3ECGpESN84IXi5YbeToejRSR6H69WWtP2vVFVzQqnuyAana6gdroYJtPC0Xyhqq+ljqpn0cnQcyLsFNs0XdMLVdupnlQjHkkhKU1zJCm8nMOYbs+iCMQD2Ml9VM3lI8d5oD2PV/IadYpGnd5T1X6NOkUjwMvba6nhaDveU6dPnrH0pHQw9Jxwno010fQ163J6kXgkDyBV24IgCIIgCIIS4pHUSDgCNOeg81Neo4SnncFP3keuPeVKZR95H6m5bhwxyOnfGLVntqWvgjo6zvSzR5p8u5HX6NXejG7nMKayJpjvIYEmr/+A6x5JKbZpRISkQ/hJNAISom4OOwUsThfNeKFAhkI1bO10VbWqaCTtO3zEoKpojNpLLnp1hjVthXOVw8XJ95VoLmcPWoWwxtC509XwunIYE0L8bbFFYdM/YO4fs0NI+58DiJDUSCRsIGLsr9p2uBjGjT6MLUE0AvaEo5PrpVs0cu05LRpV96pTNEavOXdajGrbnait5HN1
VlBTc71cDOOnvMZ0n/Kju9iG41kkRSP7j9VrOZKpeiQd2ozLiJDUSMQ4IALdKIaxE45uKSLRSTjeR90i1c8FMjpFI+DvJt92mptzxIDO0Gp0TYZ9jxTDcEQvuabiPdPdUsdJz6K2wpfYXMZrlnwTYYa2m6wZcOGAhqZIsc0BREhqJBI2EGnmW6CIRu+hKuy4IWudZ1o7LRrTfVqM7r2q3h8v9GvUWUEdvWb6WWMFNWlfcw6jk+FoO1XJHPGqOzTspGdRV77igbmcNxFCgHJbg3nIIymh7QNI1bYgCIIgCIKghHgkHaKlFMNwT/Axw80v04nT3kdVdIc1ObZUT5CxlReoca9ueB+dbPztdGNonYUj0XEc+8n3lWiuk3mNOht6R9dMPoa2RVxzOkSdzhzGRPapPbCLbQL0f7tAGAoeSUd24j4iJDUSiQCRZl7bLVU0uoHTldbpDlknmsuxZefYQeW8QMXn6ZUjBlWrqFUbf+tuDO33Yhgn8xp1NvSm5nq5pU46cxgTQgpVhfY/AXeFpKGQI2lkaGhbhKRD6BaNIhIT40Z7Hi6cXDjOvITjVCuVfS4anT4tRrWKWrVfo9+KYdLdm1GnqNbZh5G07+WWOunMYUyEtP/JOERIaiQSab7Yxg3R6AWB6EYYm8KNno5eLZBRDVHb8ahyqqq90uRb5++yJRTD2Annpjsc7XRPTadb6iiJrkbSGXpOZN9G6NxLDclFSB5AhKRDtFTRmA6c7vOo8wQZ1XlOi0YvNPnW2XaHmqtbNDrZ489vDb15OZLpD0enuw8jtabOfMXoXNM4G/mK6Qw9J4TYP18ceyhHUoRkDBGSGgmHgbCGHEkOXhaNOj2QXmjPo7MxNHeezmIYnR5Vp0+GsbV/xSbfOsPRFG7kMHqhN6POhuduHAvoiXxFgHDl28hXTGfoOQG0R9J/QtKLfSR37NiByZMn46WXXkIwGMSIESNwzz33oE2bNknnGoaBoUOHYsWKFXj++ecxfPhw9rruJhkIgiAIgiAIthk1ahQ+//xzVFRU4OWXX8bbb7+NiRMnsubOnz9fuYBJPJIOobs4xsseSF14wfsYHccY4wHvY9Re8jG0LeJamk+L8WoFNaA3r5HeW3qLYexUKqsWw3glr1FnQ++0F74AFi+ircKXNOYwJoTag0rVtuRIxrFu3TqsWLECH3zwAfr16wcAWLBgAYYOHYp58+ahc+fOCeeuXbsWf/rTn/Dhhx+iU6dOKa/tG4/kjh07MGrUKOTn56OgoACXXnopdu/e3eyc008/HYFAIO5x+eWXx43ZvHkzzj33XLRu3RqFhYX44x//iIaGBqU9GhEjWnCTgoiMhHkPrxIMWR+secGA5cElFIp/2LEfDFofljGhgPXBmEfNDQVheXD3T91r6zj6OVkexHOi9sZ53vSDsX/mcwwFA5aH6nOkxgWCBvGA9REw4h/EmGDQsDxCWdZH0PTg7oHcv3lf5N549sm9Es8p6b1JcH+ofVBrmh/Kv6OAYblfduwjFLA8AtnB+EcwYHnQf1wBy8NiKzuIQCgQ96DmITtofVB7JfbG2StnXnSu4oOAZd9FIsYBMcl9NEqDmpqauEdtba3t/VRWVqKgoCAmIgFg8ODBCAaDeP/99xPO27t3Ly6++GIsXLgQxcXFSmv7xiM5atQobNu2DRUVFaivr8e4ceMwceJELF26tNl5EyZMwC233BL7uXXr1rH/DofDOPfcc1FcXIz33nsP27Ztw+jRo5GdnY05c+Zofw5eFoQc3Ki+1tnTUTXXUWeBjM4jBqPX9Owr0ThOXiO9JnEtzcUwuhtDO5nXqNpih7umrWIejUcY6vT+Uijntuo8FlByGFO3r2jLTezkSJaUlMRdLy8vx8yZM23tp6qqCoWFhXHXsrKy0L59e1RVVSWcN2XKFAwcOBDnnXee8tq+EJJ2XLatW7dOqLJff/11fPHFF3jjjTdQVFSEPn36YPbs2bj++usxc+ZM5OTkKO/Z76JR
J1zxkYkFMlr7HTIFFWtfiq13/F4M49VwtM7CFwrdolE19EzB2YdW+7qFmJ9DzxoFImnfhq2kb8TcP1yHsBPa3rJlC/Lz82PXc3NzE86ZNm0a7rjjjmbtrlu3LrWN7OfFF1/Em2++iY8//lhpfiO+EJLJXLbnn39+wrlPPPEEHn/8cRQXF2PYsGGYMWNGzCtZWVmJ4447DkVFRbHxZWVluOKKK/D555/jhBNOIG3W1tbGuaJramoA7A9De/tLFBs73keOsBDR2Pw8L4jG6NzUx0T34Z8cRlXPIl+AErYUcxh19mGkUG2z46eWOmwhpupZtOFxS2v7nAR7oNDqRWS+YScT5GzB7hB2hGR+fn6ckGyOP/zhDxg7dmyzY7p164bi4mJ89913cdcbGhqwY8eOhM60N998Exs3bkRBQUHc9REjRuDUU0/FqlWrWHv0hZBUddlefPHFOOKII9C5c2d88sknuP7667F+/Xo899xzMbtNRSSA2M/N2Z07dy5mzZql+nQEQRAEQRCS0rFjR3Ts2DHpuAEDBmDnzp1Ys2YN+vbtCyAqFCORCEpLS8k506ZNw/jx4+OuHXfccfjzn/+MYcOGsffoqpB00mULIK7s/bjjjkOnTp0waNAgbNy4Ed27d1e2O336dEydOjX2c01NjSXnQXAOJ6uqqbmq3seoreTzVCuo6fX0eR+pcTrzGrm/D680+VbNa1TNJ9SZ16izoTc119O9GdN9LKDm0LnWcLFpnPZ8RcYLSNtxi65XbQcQNlLbQ6rjU+Hoo4/GkCFDMGHCBCxatAj19fWYNGkSRo4cGUv/27p1KwYNGoTHHnsM/fv3R3FxMemtPPzww9G1a1f22q4KSSddthSNqnzDhg3o3r07iouLsXr16rgx1dXVANCs3dzc3GZzGlKBW+Ht+GkuDh4faCeMnW7RCDhbIGNHNCqH5hVFIzXO6WIYO3l7yuczeyCv0U+FL9Q+dH8pcPJYQF/lMDLXVD8lR59ABJj3Vpd49UDVttcakj/xxBOYNGkSBg0ahMaG5Pfee2/s3+vr67F+/Xrs3btX67quCkknXbYUa9euBYBYn6QBAwbgtttuw3fffRcLnVdUVCA/Px+9evVK8dkkP2vbDmbB6bSwpNBZNGNn/05WVUftOZfrqFs0cjyGFKq5jl4phmHlBTp8vjSFqmdRZ+FL1H7yMRSqnkU7Z1WzPHhezmH0cyW0ToEI8O6tDQEofSSbp3379s12sunSpQsMo/lNJPt3Cl/kSKq4bDdu3IilS5di6NChOOSQQ/DJJ59gypQp+MUvfoHevXsDAM4++2z06tULl1xyCe68805UVVXhpptuwlVXXaXN46gLncIx3W18dDfJZtmyIRo590c1RO3lYhid4ehgFiNc6UIxjB3R6GRLHTveU52i3fehZxfa56SzCXdK9tMZZm5uH+Y17XyWNb3/3IbuDhFREJJOeyTdwhdCEkjdZZuTk4M33ngD8+fPx549e1BSUoIRI0bgpptuis0JhUJ4+eWXccUVV2DAgAE46KCDMGbMmLi+ky0VO95H1Upr1VNTvCoaqX34KYeRGqczHO30+dJ2PHoUToajdYeeOXuwZd/h6mjlEKnGffm9EjqtYebGNVVFIlcUNt2HC32NmxKJRB+pzslEfCMkU3XZlpSU4O9//3tSu0cccQReeeUVLXsUBEEQBEFoSfhGSApq2MmZM8MtmrHaVpuXaK5qVbXqvdBZIKPT+0iN83KTb6+eL62zN6PO6mg73l9f5zBSc3WePMOep+atc7oS2o18xbR6Gn2CF3Mk3UKEpEDilUprnVXV5D7SXCCjUzRG95FcqGZlU0LA/ZY6buQwUrjRUkdnDiOF6rGAnshhBHjizIXTXNJ6ckvjmhqLhSy2deUr2tiD6j4CAZeLbaAgJB3ZifuIkMwgdHoftc7TKBqpfdiqVNZYIJOV5c1iGK/mNeoUjbR9vS11VI8F9EIOo872OYDeHEatp7k42SqHO9dO
JXS6vYh2ClYcF6pJ5kr7H88gQtKDaK3QViya0VlprVM00nvQLFQdDEf7qRgGAEJZZs9Wywg9U+i8104eARidqxZ69n0TbuViG31eRJ2eWNK+nS8FCuvZ2ocdsef5huQS2m5EhKRPsdPCR2euIyucq/mbI6vS2qPhaDui0SzqAHXx5IVwtJdPc1E9y9sz7XMYH7JsweP3SuiW6kVUfN/VnuNpxoYAbHr//XzWdqZhw68tCIIgCIIgtGTEI+kydo7fU7XFmauzOTiFnUpr8ziqmMTp02JUcx2pverMa9R9vnTI9A7RUnMYqX04nsPI9TxxQtQ2wrmOV0J7NRxN2tK4f4pMD0fDfU+iLqQh+QFESPoU3Q3DVVv0ON2ehxpnFmNunBaTzRSEThfDWHMY1e2nO6/RjWMBdZ7wYiscbZlnI4dR52kuOiuhPXByC6C3MIi0n+ZwtKMFLIC20HNKqHwA2ekrpwEptjmACMk0otP7SKF6ygy9B30FLDpFI2lf82kxVqHqbDFMMKSxwMSGaKT3odHj6aMcRpaXUmMOo26PGKvYhkBnJbStXDsnvYjaT57R50VU79do4/XDIK0CMTY3yZpSbOMZREj6BLMIUi2YidoirjlYaa1TNFL74IpGMqys2JvRjWIY6vxqP7XUMe/fT6Hn6DiGOPNCqxzAsjedDbfZ+/CoF1GrQNS+pqpgYwrVdAtCO2Iv2ZpueyQjAUQiqT2/VMf7BRGSDmHH+6izz6POSmvVELVO0Ri1p2ZftfWOzobeHLGZin3O+dIcARq15WxLHV+Fnp1uwq2xz6MnvIg6BSLgrBdRc9Wzr/MOnfQYOrGmh4iEA4iEUxSSKY73CyIkNRIMBhBsptu+HdHI6fNo5yg8J/MadRfDmO3ZafLNKX5hizoiNKx6LKC5yAXgiTjdLXXMz8nToed0N+HWmVens1UOcc2zYeZEa5rteznM7IW8Qy96DJ22JR5Jz5AZXw0EQRAEQRCEtCMeSYfQ3TBctT2PFwpkdHofqTWpfWWTxxXy1rS0kbFRDGMOK+s+X9qad+jdljq+ymF0uBJa58ktfs5XBDxc9ayxgMVWy5t0extthbt1eimT7N/ls7bFI3kAEZIaCYaaf+1Toku1PY8b51fTZ0kzRCkxhnMuNbUvapwd0UjmDzKqi+lcR7ViGE61dKJ9cELnXPuqeYeeyWH0cyW0Cye3tNhwtOYKZ0dDzynsQ9ua6RSDKdsL0v/tAiIkDyBC0iF0ikbAKox0n19tFmeqopEaZ0c0cvo1qlZQA7QIUu3NyMlr5NvyRksdswjNxCbcmdhw2+9eREcLWHRXOHvVY+hVQZghxTZGJPViG0OEpJCMQJNiG52ikRqnUzQCVrHHta8aeqYLcNTC3TqLYai5lNeSG47mhbbV9+9k6Jmy5+nQc7orof3kRXS6d6KPCli0C0Q/eQy9IAh1CUnxSHoGEZIaCYUOvP95paqaeySf2T775BbGmjorqAH1HEbVcHSIaZ9lixFKT8W+5QsG076y4HQ69Kyzn6KdVjlp9iK6EmZuKXmHTnsM0y2o/FxpndCWwnPSHTZPERGSB8gMH7MgCIIgCIKQdsQjqZGmfSSdrqrmhJRTsW/2QOo8LcbpYhg7Tb5V8xrpcDcjTK6x8IWyz8n5pOZFB7oQjlbZQ4J9sM5/trWmg+Fo3fmKUsCS8prq9j2aI5luW0B6Q+eueySjj1TnZCIiJDUSDB54beuuqmY14SbmkSFqxt64QtXpYhjVJt/cvEbV86VVQ9Q6C2sAh3MYAaIE3+EcRjdObtEZjlYM8fr+2D7u64ILR+B4JfTsZ0HoZFW1E3ipaltOtokhQlIjwVDqHknaW0fMZfSR5IpGThW1qmgEeMUwTvdmtJPXyKraZh5FqCr06P2nN4cRcLYSWufJLaQ93VXPOj2qlnkeqXD2U8ubTPQYekEQekHg+gDJkTyACEmNBAMH/m50ikZAvck3t/WO1T4zjM3wBtoJt3LC0brP
l7YeC8jdl8PHApJnfjscetbpRUxiO9E1x8PMbngRPeAxJNfUWcDSUjyGbhedaN+DB56P27YZiJA8gAhJjQRDgZgI1CkaAfUcSU4OI2Ddr85wNLu9jWI4Wnu42HJaDDGP6WVNd+g5kM0Tkl6ohHa8d6INIabsRXQ677Cl5hi6Uansp6pqr4o/Jz2SLns7w+EAgimGqsMOh7Z37NiByZMn46WXXkIwGMSIESNwzz33oE2bNs3Oq6ysxI033oj3338foVAIffr0wWuvvYZWrVqx1s1cv7MgCIIgCEILYdSoUfj8889RUVGBl19+GW+//TYmTpzY7JzKykoMGTIEZ599NlavXo0PPvgAkyZNQjAFoS4eSY2EggecFDq9j9Q4nd7H6DhTaFtjMQw3zMzNa+ScL83NwdSaI8nxsvooh5Ec56dwtMZTbBLZ49lvAaFnr+RD6pwrOZLOkCkNyQ2F0LbhnEdy3bp1WLFiBT744AP069cPALBgwQIMHToU8+bNQ+fOncl5U6ZMwdVXX41p06bFrh111FEprS1CUiOh7ACy9v9B6xSN1DWdxTCAVVDpbPLNabEDpCA4GTmY7NNoHD4W0BJq9nIOY5rD0XYaklOonoXthQIW+l5kYOjZ7/mQuudabGVYjiRBIKBHTOmyo4qhkCPZeERiTU1N3PXc3Fzk5uba2k9lZSUKCgpiIhIABg8ejGAwiPfffx/nn3++Zc53332H999/H6NGjcLAgQOxceNG9OzZE7fddhtOOeUU9tq+EZKpxv6/+eYbdO3alfy3p556Cr/5zW8A0C/Gv/3tbxg5cmTKewwGDwg8naIxajv+mh3RqLM3I0ewcXsnqnpBdedIcjySnslh5FRCe9WL6EKfRFsFLKqCkMKrHkOvCELLPI/uC8g48adVoDkpSsk35vRhp9impKQk7np5eTlmzpxpaz9VVVUoLCyMu5aVlYX27dujqqqKnPP1118DAGbOnIl58+ahT58+eOyxxzBo0CB89tlnOPLII1lr+0ZIjho1Ctu2bUNFRQXq6+sxbtw4TJw4EUuXLiXHl5SUYNu2bXHXHnjgAdx1110455xz4q4/8sgjGDJkSOzngoICpT1mZQWQtf+DRKdoBKzCUbWCGnA+HK16vrSqF1S18CXRXEulNSX+nA49Z1tzElw5/9lBLyJbILaEAhanPWK6hVimCUKfF8j4RvzptO92aNtGH8ktW7YgPz8/dr05b+S0adNwxx13NGt33bp1Ke0jtp/9HdIvu+wyjBs3DgBwwgknYOXKlVi8eDHmzp3LsuMLIakS+w+FQiguLo679vzzz+OCCy6weDELCgosY1VomiPJFY205zL5ODs5jE6HozkePTsnvOhsL8TyWlGizu+hZy94EXWLXjN2BGJL8BhKXqP99WzM9az480Ko3m3bDOx4JPPz8+OEZHP84Q9/wNixY5sd061bNxQXF+O7776Lu97Q0IAdO3Yk1DedOnUCAPTq1Svu+tFHH43Nmzez9gf4REiqxP7NrFmzBmvXrsXChQst/3bVVVdh/Pjx6NatGy6//HKMGzeu2T/y2tpa1NbWxn5uzHfIyj7gkdQpGgGrMLKTw6gajuYeC8jpw8gRswnnMkLbrJ6LAMuzSAoNwkvpidCznWKVNHsR7eVbMry/5DyNgo38luORXD4/5TWScx0Uqjb25Qnx5xehBzjrNXTbIxlJvS+kyhGJHTt2RMeOHZOOGzBgAHbu3Ik1a9agb9++AIA333wTkUgEpaWl5JwuXbqgc+fOWL9+fdz1r776yhK5bQ53fxNMVGL/Zh5++GEcffTRGDhwYNz1W265BU899RQqKiowYsQIXHnllViwYEGztubOnYu2bdvGHuZ8B0EQBEEQhHRx9NFHY8iQIZgwYQJWr16Nd999F5MmTcLIkSNjUdutW7eiZ8+eWL16NYDoF6M//vGPuPfee/HMM89gw4YNmDFjBr788ktceuml7LVd9Ug6Gftvyr59+7B06VLM
mDHD8m9Nr51wwgnYs2cP7rrrLlx99dUJ7U2fPh1Tp06N/VxTU4OSkpJo1XYzHkmq7Q6nghqwevm4nkadeY3c86U5Db05BT8J1zTbY+YYKldHO5zDqDUcrbsJt5PhaKdb3riRF+jlcLeyfY3P284+NM2z5VXMNC+ibq9eOsPNPg5tO8UTTzyBSZMmYdCgQbGi5HvvvTf27/X19Vi/fj327t0bu3bNNdfgp59+wpQpU7Bjxw4cf/zxqKioQPfu3dnruioknYz9N+WZZ57B3r17MXr06KRjS0tLMXv2bNTW1iZMgE1Uqt80tM0VjTqLYai8RtWKaZ35lrYKXxSPBdRaHc3NYcxhjLOTr0jBEb1uhKO90PLG6dCzn4SYV0Sp4lxXxJ9OW14Rf34OZZvtZzU4u1YS7BTbOEX79u0TFiAD0VC2YVg/Z6dNmxbXRzJVXBWSTsb+m/Lwww/jV7/6FWuttWvXol27dko9nbJDgZiAtCMaOd5G1ebdieybvYHs4wo1thJSrY7WmcNIjvNwDiOrd6JXvIiqgqoleAxbSLW0siD0U+GOzj14aR+esB+g/9sFvNaQ3E18UWzTNPa/aNEi1NfXk7H/QYMG4bHHHkP//v1jczds2IC3334br7zyisXuSy+9hOrqapx00knIy8tDRUUF5syZg2uvvVZpn9m5QPb+OxoMWbNqueFoTsU0VyC60VKHdZqLxtAzV4B6ohJaY89Iaq4rfRLJF4HD4skNcSaCMEbaBaGf+jzqFjhe8KiStlwQRU3373qxjfdC227hCyEJqMX+AWDx4sU47LDDcPbZZ1tsZmdnY+HChZgyZQoMw0CPHj1w9913Y8KECUp7DIaMmNiyIxo5FdOq+YoJ55pT4ZghcTr0bDKmsX1OdM14e36qhLYTZtbZJ9GVvEM/ewzdWDMTw8W+6vPogvfUaXt+8m4mW9PlhuSGQmjbcDi07Ra+EZKqsf85c+Zgzpw55JwhQ4bENSIXBEEQBEEQ+PhGSPqBULaBrOyomCU9gczejJwQtZ2ekarHArJDyOZxLpzm4tlwtM6jCYm9aS9gUfXCtYTQs0cKZDybd9hSvIgt1WOYznxIlX93mEgkgICEtgGIkNRKVk4TIalRNALUyTMacxhBiD/uEYCkuDTnMCoKxATXlCuhGXtNuDfLPEVbuo/o81Po2e+hbWX7Pg8zuyLOfJSDabHlwu/SaVuk/ZadI4mIEX2kOicDESGpkVB2BKHsaJGNzsIXgMiRtJPDmE0lZnIKWHiiyyIc7eQrcvIrfVTA4iuPodP2/VaY4qQg9Mq98P2aPtqr47YcFnpuCzmXCUYMBFMUhoYISSEZoawDIlC3R1I59MwtdDHNZbfK4XgkVQUi4E4Bi5Mtb3S3t8m06uiWUpXs1dCtn0LDXgkDe1X8OS30XC52cXv9QNhAIJyaMEx1vF8QIamRUJaBUHaKVdtMwckKFzObcHNyHZUbbgOsvXom75BT0azTo+dV76Ad++Q49z2SbDEooWH7a7aUvZL2PCr+3BBa6VzTbSGp4JGMZKhHsmX7pgVBEARBEARlxCOpkVBOBFk50RxJ9nGCnD6MgCUUzM5h5Ba65DA8kqqNxTUf26dawKLcTFu1cMRpW+QYF+zbGWeixXoR/ZRPCPjbi+iVfEinPWpO23c7R9Ll9YNG6h7JINGiMBMQIamRrCTtf5Tb5wCsHEZS1JHijyHYdFY928lXpI465ISeKbj5iU7mNfqp6pk5Tqv487K4EfGUnnlAywgX+13ouZ4j6ULVeNPlIwYCKQrJVMf7BRGSGgnmAMH9R3RTldFacxjZHkmeILGIMxtnVfOKbRh7ANLvMaTmOp0jqbovznoJ8IT4c7ywxmHx1JKFnsVWBnr5dNryu6hzWzRSuLynYBgIplg8Eww7tBmXESGpkUBuFgK50VtqK/TsdCU0RyRq9CLaOp853R5Dyp7ToW3uvggsglB36NlJL5zfxJN4JO3jBaEnos4faybDbSGp
UGyT6ni/IEJSI4G8IAK5UTGnmpuYcJw5R1JVIALp9yKqCkTuXN2VymkOR5PeQa+21PFIex56ro+EnhdCt14QdYnwqnjNRPtm3M595OJ2jqQIyRg+ecUIgiAIgiAIXkM8khoJtspGMC87+oPOfEXA0Ybh0X2YPJI6w9E6w9jUXDdyJAmUPYuO5016xIvYEjyGXvGkeTW/z8te0HTbd8Oj5sUQtSoe6CMpxTZRREhqJJAXQiBvfwhaZ9UzAOSEko/htsHhtOxxOvRM4ZEcSVYhis6qai+359G5pq55gDtCzwsiqKWIOidFgt9zJL2ypk5U9h90V75IaPsAIiQ1Esg5UGzD7rnI9lw66DGkrrnhMdRYCW2rKtlJQeiV9j/cNVXG2NmHG0Iv00Sdbnt+8sz53Uvplz0kIq17czlHMmwoVG2LkBSSEMjLQiAvtaptT3gMqWtOC0kKriBUFTzKQtWFVjleXbOlCD2vCrGWELLORPtcvLIPM14swHG72EYakscQIamRQKsQAq33C0muxzCLEo0Oe/kcFpKOhoYBG0LVhRxGZVs+En9+EnotQdQB/hJe6RZPbog1LwoxAIEA8fnjEwJui26FHElIaFtIRtM+kmyh5FEhyfYEupEX6Gfxp7vwxWlBle7WMlJgkj5bmWjfA4LNM+LMbaHlNNJH0jNk+CtNEARBEARBcArxSOokOxvI2d/+x5b3UZ9HUtmzyPZkpjmMzbVva5zGNZ0MDaeyD9U1vRBCTrdXNF32nLSdYZ4/8fIlwav7chK3PZJSbBNDhKROcpoIySzi1urMO9Qe2vbPsYDK9il0Cj03BI8UsGSuLQoXQreuiDgvCCMv7IGLn/aqC+kj6RlESOokLwfIy43+t07vIMAUeh5t8q07348zzis5hrrmAd7wzHnVlhP2zDgo4lqMWPOq4PHqvig8kAfqCdyu2o5EEIxEUp6TiYiQ1EggKxuB7GZC25TQUw13a27CzRJ/Tgs9p8WZn7xwXhVsPhZrQBoEm1cEiRf24YU9cMlEcean+6+C26FtDxbb7NixA5MnT8ZLL72EYDCIESNG4J577kGbNm0SzqmqqsIf//hHVFRU4Mcff8RRRx2FG2+8ESNGjGCvK0JSJ3k50QfgnfY8qiLLaSGpM1zsp9Cw07b8LsTIRVuo58wLe0iEn4WXl+9ruvHzvXBbSHowR3LUqFHYtm0bKioqUF9fj3HjxmHixIlYunRpwjmjR4/Gzp078eKLL6JDhw5YunQpLrjgAnz44Yc44YQTWOv6+FUkCIIgCIIgrFu3DitWrMBDDz2E0tJSnHLKKViwYAGWLVuG//73vwnnvffee5g8eTL69++Pbt264aabbkJBQQHWrFnDXls8kjpJkiNpy4vI8RhScGxx57lRqeynXERqG5kWSvWKB8Mr+zDjZ08d4N376jQt9Xl7mWS/E7eLbQyFYhsHT7aprKxEQUEB+vXrF7s2ePBgBINBvP/++zj//PPJeQMHDsSTTz6Jc889FwUFBXjqqafw008/4fTTT2evLUJSJzm5QO5+IckVYlyh53TeYbrFn40PXGVx5tVwtBv2VfGTUPLqPdRNS3me6Ubuq9AMdnIka2pq4q7n5uYit1E7KFJVVYXCwsK4a1lZWWjfvj2qqqoSznvqqadw4YUX4pBDDkFWVhZat26N559/Hj169GCv7Rshedttt2H58uVYu3YtcnJysHPnzqRzDMNAeXk5HnzwQezcuRMnn3wy7r//fhx55JGxMSrJqQnJOwjIa5X431UFIjVOc0saR8WZl6tznRZGfvow8tNe3UDuj6CKn76A+QXXq7bVhWRJSUnc9fLycsycOZOcM23aNNxxxx3N2l23bl1K+2jKjBkzsHPnTrzxxhvo0KEDXnjhBVxwwQX4xz/+geOOO45lwzdCsq6uDr/5zW8wYMAAPPzww6w5d955J+69914sWbIEXbt2xYwZM1BWVoYvvvgCeXl5ANSSUxOS3RrIiQpJtjDz
aojXq0fEOWHPa+ulA/lgEwRBUCYYNhAMqhXbbNmyBfn5+bHrzXkj//CHP2Ds2LHN2u3WrRuKi4vx3XffxV1vaGjAjh07UFxcTM7buHEj7rvvPnz22Wc45phjAADHH388/vGPf2DhwoVYtGgR52n5R0jOmjULAPDoo4+yxhuGgfnz5+Omm27CeeedBwB47LHHUFRUhBdeeAEjR46MJad+8MEHsbyCBQsWYOjQoZg3bx46d+6c0h4DuW0QyG2d0hxPVA1TeFloZKKwU0XuRcvCyMw+dIKQMlRf5nQSUWgwvv/PNz8/P05INkfHjh3RsWPHpOMGDBiAnTt3Ys2aNejbty8A4M0330QkEkFpaSk5Z+/evQCAoOnzPhQKIZJCz8uM/RTatGkTqqqqMHjw4Ni1tm3borS0FJWVlQCSJ6cmora2FjU1NXEPQRAEQRAENzj66KMxZMgQTJgwAatXr8a7776LSZMmYeTIkTGn2NatW9GzZ0+sXr0aANCzZ0/06NEDl112GVavXo2NGzfiT3/6EyoqKjB8+HD22r7xSKZKY3JpUVFR3PWioqLYv6kmp86dOzfmIW1KTW0AqE31W5JiFVfA6aOWwg7bt4OX99YCES9Zy0N+54LTJHmN1dREvWmGg5XQzRGu24uGFD2S4YZ9Du0myhNPPIFJkyZh0KBBsZqPe++9N/bv9fX1WL9+fcwTmZ2djVdeeQXTpk3DsGHDsHv3bvTo0QNLlizB0KFD2eu6KiS5SaQ9e/ZM0454TJ8+HVOnTo39vHXrVvTq1QslR1zk4q4EQRAEoWXx448/om3btmlbLycnB8XFxXj29WuU5hcXFyMnJ0fvpvbTvn37Zus7unTpYhHeRx55JJ599llb67oqJLlJpCo0JpdWV1ejU6dOsevV1dXo06dPbEyqyamAtVS/TZs22LJlCwzDwOGHH25JpBWcp6amBiUlJXLvXUDuvTvIfXcPuffu0XjvN2/ejEAgkHItg13y8vKwadMm1NXVKc3PycmJFftmCq4KSW4SqQpdu3ZFcXExVq5cGROONTU1eP/993HFFVcAUEtOpQgGgzjssMNiuZKpJNIKepF77x5y791B7rt7yL13j7Zt27p27/Py8jJODNrBN8U2mzdvxtq1a7F582aEw2GsXbsWa9euxe7du2Njevbsieeffx4AEAgEcM011+DWW2/Fiy++iE8//RSjR49G586dY0mknORUQRAEQRAEgcY3xTY333wzlixZEvu58TDxt956K3aUz/r167Fr167YmOuuuw579uzBxIkTsXPnTpxyyilYsWJF3DeJZMmpgiAIgiAIAo1vhOSjjz6atIekOYk0EAjglltuwS233JJwTrLk1FTIzc1FeXm57aOOhNSRe+8ecu/dQe67e8i9dw+5994jYLhVOy8IgiAIgiD4Gt/kSAqCIAiCIAjeQoSkIAiCIAiCoIQISUEQBEEQBEEJEZIpsnDhQnTp0gV5eXkoLS2NnVmZiKeffho9e/ZEXl4ejjvuOLzyyitp2mnmkcq9f/DBB3HqqaeiXbt2aNeuHQYPHpz0dyUkJtXXfSPLli1DIBBI6dxW4QCp3vedO3fiqquuQqdOnZCbm4uf/exn8p6jSKr3fv78+TjqqKPQqlUrlJSUYMqUKfjpp5/StNvM4O2338awYcPQuXNnBAIBvPDCC0nnrFq1Cj//+c+Rm5uLHj16JC3KFRzAENgsW7bMyMnJMRYvXmx8/vnnxoQJE4yCggKjurqaHP/uu+8aoVDIuPPOO40vvvjCuOmmm4zs7Gzj008/TfPO/U+q9/7iiy82Fi5caHz88cfGunXrjLFjxxpt27Y1/vOf/6R55/4n1XvfyKZNm4xDDz3UOPXUU43zzjsvPZvNIFK977W1tUa/fv2MoUOHGu+8846xadMmY9WqVcbatWvTvHP/k+q9f+KJJ4zc3FzjiSeeMDZt2mS89tprRqdOnYwpU6akeef+5pVXXjFuvPFG47nnnjMA
GM8//3yz47/++mujdevWxtSpU40vvvjCWLBggREKhYwVK1akZ8OCYRiGIUIyBfr3729cddVVsZ/D4bDRuXNnY+7cueT4Cy64wDj33HPjrpWWlhqXXXaZo/vMRFK992YaGhqMgw8+2FiyZIlTW8xYVO59Q0ODMXDgQOOhhx4yxowZI0JSgVTv+/33329069bNqKurS9cWM5ZU7/1VV11lnHnmmXHXpk6dapx88smO7jOT4QjJ6667zjjmmGPirl144YVGWVmZgzsTzEhom0ldXR3WrFmDwYMHx64Fg0EMHjwYlZWV5JzKysq48QBQVlaWcLxAo3Lvzezduxf19fVo3769U9vMSFTv/S233ILCwkJceuml6dhmxqFy31988UUMGDAAV111FYqKinDsscdizpw5CIfD6dp2RqBy7wcOHIg1a9bEwt9ff/01XnnlFQwdOjQte26pyGesN/BNQ3K32b59O8LhMIqKiuKuFxUV4csvvyTnVFVVkeOrqqoc22cmonLvzVx//fXo3Lmz5U1HaB6Ve//OO+/g4Ycfxtq1a9Oww8xE5b5//fXXePPNNzFq1Ci88sor2LBhA6688krU19ejvLw8HdvOCFTu/cUXX4zt27fjlFNOgWEYaGhowOWXX44bbrghHVtusST6jK2pqcG+ffvQqlUrl3bWshCPpJDx3H777Vi2bBmef/75uOMxBf38+OOPuOSSS/Dggw+iQ4cObm+nRRGJRFBYWIgHHngAffv2xYUXXogbb7wRixYtcntrGc+qVaswZ84c/OUvf8FHH32E5557DsuXL8fs2bPd3pogOI54JJl06NABoVAI1dXVcderq6tRXFxMzikuLk5pvECjcu8bmTdvHm6//Xa88cYb6N27t5PbzEhSvfcbN27EN998g2HDhsWuRSIRAEBWVhbWr1+P7t27O7vpDEDlNd+pUydkZ2cjFArFrh199NGoqqpCXV0dcnJyHN1zpqBy72fMmIFLLrkE48ePBwAcd9xx2LNnDyZOnIgbb7wRwaD4bJwg0Wdsfn6+eCPTiLy6meTk5KBv375YuXJl7FokEsHKlSsxYMAAcs6AAQPixgNARUVFwvECjcq9B4A777wTs2fPxooVK9CvX790bDXjSPXe9+zZE59++inWrl0be/zqV7/CGWecgbVr16KkpCSd2/ctKq/5k08+GRs2bIgJdwD46quv0KlTJxGRKaBy7/fu3WsRi42C3pBTiB1DPmM9gtvVPn5i2bJlRm5urvHoo48aX3zxhTFx4kSjoKDAqKqqMgzDMC655BJj2rRpsfHvvvuukZWVZcybN89Yt26dUV5eLu1/FEn13t9+++1GTk6O8cwzzxjbtm2LPX788Ue3noJvSfXem5GqbTVSve+bN282Dj74YGPSpEnG+vXrjZdfftkoLCw0br31Vreegm9J9d6Xl5cbBx98sPG3v/3N+Prrr43XX3/d6N69u3HBBRe49RR8yY8//mh8/PHHxscff2wAMO6++27j448/Nr799lvDMAxj2rRpxiWXXBIb39j+549//KOxbt06Y+HChdL+xwVESKbIggULjMMPP9zIyckx+vfvb/zzn/+M/dtpp51mjBkzJm78U089ZfzsZz8zcnJyjGOOOcZYvnx5mnecOaRy74844ggDgOVRXl6e/o1nAKm+7psiQlKdVO/7e++9Z5SWlhq5ublGt27djNtuu81oaGhI864zg1TufX19vTFz5kyje/fuRl5enlFSUmJceeWVxv/+97/0b9zHvPXWW+T7duO9HjNmjHHaaadZ5vTp08fIyckxunXrZjzyyCNp33dLJ2AY4ncXBEEQBEEQUkdyJAVBEARBEAQlREgKgiAIgiAISoiQFARBEARBEJQQISkIgiAIgiAoIUJSEARBEARBUEKEpCAIgiAIgqCECElBEARBEARBCRGSgiAIgiAIghIiJAVBEARBEAQlREgKgtBiOf3003HNNde4vQ1BEATfIkJSEARBEARB
UELO2hYEoUUyduxYLFmyJO7apk2b0KVLF3c2JAiC4ENESAqC0CLZtWsXzjnnHBx77LG45ZZbAAAdO3ZEKBRyeWeCIAj+IcvtDQiCILhB27ZtkZOTg9atW6O4uNjt7QiCIPgSyZEUBEEQBEEQlBAhKQiCIAiCICghQlIQhBZLTk4OwuGw29sQBEHwLSIkBUFosXTp0gXvv/8+vvnmG2zfvh2RSMTtLQmCIPgKEZKCILRYrr32WoRCIfTq1QsdO3bE5s2b3d6SIAiCr5D2P4IgCIIgCIIS4pEUBEEQBEEQlBAhKQiCIAiCICghQlIQBEEQBEFQQoSkIAiCIAiCoIQISUEQBEEQBEEJEZKCIAiCIAiCEiIkBUEQBEEQBCVESAqCIAiCIAhKiJAUBEEQBEEQlBAhKQiCIAiCICghQlIQBEEQBEFQQoSkIAiCIAiCoMT/D/1zRfsC+Fq7AAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Evaluation grid: N_t time instants in [0, 1] and N_x spatial points in [-1, 1]\n", "N_t, N_x = 100, 256\n", "\n", "t = np.linspace(0.0, 1.0, N_t)\n", "x = np.linspace(-1.0, 1.0, N_x)\n", "T, X = np.meshgrid(t, x, indexing='ij')  # both arrays have shape (N_t, N_x)\n", "coords = np.stack([T.flatten(), X.flatten()], axis=1)  # (N_t*N_x, 2) array of (t, x) pairs\n", "\n", "# Evaluate the model on every grid point and reshape the predictions back to (N_t, N_x)\n", "output = model(jnp.array(coords))\n", "u_pred = np.array(output).reshape(N_t, N_x)\n", "\n", "plt.figure(figsize=(7, 4))\n", "plt.pcolormesh(T, X, u_pred, shading='auto', cmap='Spectral_r')\n", "plt.colorbar()\n", "\n", "plt.title(\"Solution of Burgers' Equation\")\n", "plt.xlabel('t')\n", "plt.ylabel('x')\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "88cc341a-8869-4dd9-be77-ef9bf2d8b5c1", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 5 }